import pdb

from mitmproxy.flow import Flow
import threading
import time
import zmq
import json
import os
from enum import Enum
from dataclasses import dataclass
from typing import List, Dict, Any
from queue import Queue, Empty


# Convert flow states (as produced by get_state()) into JSON-serializable values.
def convert_to_strings(obj: Any) -> Any:
    """Recursively replace bytes in *obj* with str so it can be JSON-encoded.

    Intended for flow states produced by ``get_state()``. Dict keys and values
    and list/tuple elements are converted recursively (tuples become lists).
    Bytes are decoded as UTF-8; bytes that are not valid UTF-8 are replaced by
    the body of their ``repr`` (the text between ``b'`` and the closing quote),
    so undecodable bytes stay visible as escape sequences. Any other value is
    returned unchanged.
    """
    if isinstance(obj, dict):
        return {convert_to_strings(key): convert_to_strings(value)
                for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_strings(element) for element in obj]
    if isinstance(obj, bytes):
        # Plain UTF-8 decoding keeps the data intact. A unicode-escape/latin1
        # round trip would also interpret literal backslash sequences in the
        # raw bytes, e.g. turn the two characters backslash + "n" into a
        # newline and strip the backslash from JSON escapes.
        try:
            return obj.decode("utf-8")
        except UnicodeDecodeError:
            return str(obj)[2:-1]
    return obj


@dataclass
class bHeader:
    """A single HTTP header, stored as a name/value pair of strings."""

    # header name as it appeared in the flow
    key: str
    # header value as it appeared in the flow
    value: str


@dataclass
class bRequest:
    """HTTP request metadata extracted from a flow state dict for the plugin.

    The dataclass fields are filled by the custom ``__init__`` below (the
    dataclass-generated ``__init__`` is skipped because one is defined), and
    ``json()`` turns the instance into a JSON-ready dict.
    """
    server_ip_address: str

    tls: str
    content: str
    scheme: str
    method: str
    host: str
    port: int
    http_version: str
    path: str
    timestamp_start: float
    timestamp_end: float
    # [("Header","Data")]
    headers: List[bHeader]

    error: str

    # init from flow dict
    def __init__(self, flow: dict):
        """Build from a flow state dict (as returned by ``get_state()``).

        Every bytes value in *flow* is converted to str first. Raises KeyError
        if a required entry of ``flow["request"]`` or ``flow["server_conn"]``
        is missing.
        """
        flow = convert_to_strings(flow)
        server = flow["server_conn"]
        request = flow["request"]

        # NOTE(review): ip_address is presumably an (address, port) pair that
        # can still be None before the upstream connection exists; guard it
        # instead of failing with TypeError on None[0].
        ip_address = server.get("ip_address")
        self.server_ip_address = ip_address[0] if ip_address else None
        self.tls = server["tls_established"]
        self.content = request["content"]
        self.scheme = request["scheme"]
        self.method = request["method"]
        self.host = request["host"]
        self.port = request["port"]
        self.http_version = request["http_version"]
        # path and error are declared fields; leaving them unset made the
        # dataclass-generated __repr__/__eq__ raise AttributeError.
        self.path = request["path"]
        self.timestamp_start = request["timestamp_start"]
        self.timestamp_end = request["timestamp_end"]

        # header entries are (name, value) pairs
        self.headers = [bHeader(str(k), str(v))
                        for k, v in request.get("headers", ())]

        # NOTE(review): assumes the optional top-level "error" entry of the
        # flow state is None or a dict carrying a "msg" -- confirm against the
        # mitmproxy version in use.
        error = flow.get("error")
        self.error = error.get("msg") if isinstance(error, dict) else error

    def json(self) -> dict:
        """Return the fields as a dict with headers flattened to {name: value}.

        NOTE(review): repeated header names (e.g. several Cookie headers) keep
        only the last value here, unlike bResponse.json(), which emits a list.
        """
        state = vars(self).copy()
        state["headers"] = {h.key: h.value for h in self.headers}
        return state


@dataclass
class bResponse:
    """HTTP response metadata extracted from a flow state dict for the plugin.

    The custom ``__init__`` reads the ``"response"`` entry of the converted
    flow state; ``json()`` returns a JSON-ready dict.
    """
    status_code: int
    http_version: str
    reason: str
    content: str
    timestamp_start: float
    timestamp_end: float
    # [("Header","Data")]
    headers: List[bHeader]

    def __init__(self, flow: dict):
        """Build from a flow state dict; every bytes value is converted to str first."""
        response = convert_to_strings(flow)["response"]
        # copy the scalar entries verbatim, in declaration order
        for attr in ("status_code", "http_version", "reason", "content",
                     "timestamp_start", "timestamp_end"):
            setattr(self, attr, response[attr])

        # header entries are (name, value) pairs
        self.headers = [bHeader(name, value)
                        for name, value in response.get("headers", {})]

    def json(self) -> dict:
        """Return the fields as a dict; headers become a list of one-entry
        ``{name: value}`` dicts, so repeated header names are all kept."""
        state = dict(vars(self))
        state["headers"] = [{hdr.key: hdr.value} for hdr in self.headers]
        return state


# Deliberately not a @dataclass: decorating an Enum generates a field-less
# __eq__ that returns True for any two members (and sets __hash__ to None),
# so every state would compare equal to every other state.
class bFlowState(Enum):
    """Where a tracked flow is in its exchange with the bigsnitch plugin.

    UNSENT_* still has to be sent to the plugin; SENT_* has been sent and is
    waiting for a reply. Flows in ERROR are dropped by the network thread.
    """
    ERROR = 0
    UNSENT_HTTP_REQUEST = 1
    SENT_HTTP_REQUEST = 2
    UNSENT_HTTP_RESPONSE = 3
    SENT_HTTP_RESPONSE = 4


class bPacketType(int, Enum):
    """Message types exchanged with the bigsnitch plugin (the "type" field).

    Members are ints, so they compare equal to the raw integers that arrive
    in decoded JSON replies and are serialized as plain numbers by
    ``json.dumps``. (Previously a @dataclass holding int constants, which
    generated a meaningless __init__/__eq__ and left the set of types open.)
    """
    NACK = 0
    ACK = 1
    KILL = 2
    WARNING = 3
    ERROR = 4
    PING = 5
    HTTP_REQUEST = 6
    HTTP_RESPONSE = 7


@dataclass
class bPacket:
    """One message exchanged with the bigsnitch plugin.

    NetworkThread.send_packet serializes it as
    ``{"type": ptype, "id": flowid, "data": data}``.
    """
    # message type; one of the bPacketType values
    ptype: bPacketType
    # id of the flow the message refers to
    flowid: str
    # payload, e.g. the dict returned by bRequest.json()/bResponse.json();
    # previously annotated as str, which did not match what is sent
    data: Any


@dataclass
class FlowItem:
    """Bookkeeping for one flow that the network thread exchanges with the plugin."""

    # where the flow currently is in the request/response exchange
    state: bFlowState
    # the mitmproxy flow itself; NetworkThread tests it for truthiness before
    # calling kill(), so it may presumably be None
    flow: Flow
    # time.monotonic() value of the last send to the plugin; 0 until sent
    time: float = 0
    # sends left before NetworkThread gives up and kills the flow
    retries: int = 5


"""

The network thread communicates with the bigsnitch plugin using zeromq.

"""
class NetworkThread(threading.Thread):
    """Background thread that relays mitmproxy flows to the bigsnitch plugin.

    New flows arrive from the mitmproxy thread through ``queue`` as
    ``(flow_id, FlowItem)`` tuples. Each flow is serialized and sent to the
    plugin over a ZeroMQ PAIR socket. A send without a reply is repeated every
    ``RESEND_TIMEOUT`` seconds until the item's retries run out, and then the
    flow is killed. An ACK from the plugin resumes the flow; a KILL kills it.
    Call ``stop()`` to make ``run()`` return.
    """

    # seconds to wait for a plugin reply before resending a flow
    RESEND_TIMEOUT = 5
    # milliseconds handle_packets() waits for incoming data per poll; this
    # also paces the main loop while nothing is happening
    POLL_TIMEOUT_MS = 50

    def __init__(self, name: str, queue: Queue, path: str = None):
        super().__init__()
        self.name = name
        # ZeroMQ endpoint: explicit argument, then $BIGSNITCH_PATH, then default
        self.path = (path
                     or os.environ.get("BIGSNITCH_PATH", None)
                     or "tcp://127.0.0.1:12345")

        # queue for communicating with the main mitmproxy thread;
        # contains tuples of (flow_id, FlowItem)
        self.q = queue
        # for zmq use
        self.context = zmq.Context()
        # created by connect(), closed by disconnect()
        self.socket = None
        # flows currently being handled, keyed by flow id
        self.flows: Dict[str, FlowItem] = {}
        # reserved for a connection liveness ping (not implemented yet)
        self.timer = time.monotonic()
        # retries left for reconnecting (not used yet)
        self.retries = 5
        # set by stop() to make run() return
        self._stop_requested = threading.Event()

    def stop(self):
        """Ask run() to return after its current loop iteration."""
        self._stop_requested.set()

    # send a single message, no checks involved
    def send(self, msg):
        """Serialize *msg* as JSON and send it as one UTF-8 encoded frame."""
        self.socket.send(json.dumps(msg).encode("utf-8"))

    def get_new_flows(self):
        """Move every flow waiting in the queue into self.flows.

        A flow id that is already tracked is reported and ignored; it must not
        raise, since an exception here would end the thread.
        """
        while True:
            try:
                flow_id, item = self.q.get_nowait()
            except Empty:
                return
            if flow_id in self.flows:
                print(f"flow {flow_id} doubled? ignoring...")
                continue
            self.flows[flow_id] = item

    def send_packet(self, pkg: bPacket):
        """Send *pkg* to the plugin as a {"type", "id", "data"} JSON object."""
        self.send({"type": pkg.ptype, "id": pkg.flowid, "data": pkg.data})

    def update_flows(self):
        """Advance the send/resend state machine of every tracked flow.

        Unsent flows are sent at once. Sent flows without a reply are resent
        once RESEND_TIMEOUT seconds have passed. Every send consumes one retry,
        and a flow whose last attempt also went unanswered is killed and
        dropped. Flows in the ERROR state are dropped.
        """
        # Enum members are singletons, so states are compared by identity.
        # Iterate over a copy so entries can be deleted inside the loop.
        for flow_id, item in list(self.flows.items()):
            state = item.state

            if state is bFlowState.ERROR:
                print(f"error in flow {flow_id}!")
                del self.flows[flow_id]
                continue

            awaiting_reply = (state is bFlowState.SENT_HTTP_REQUEST
                              or state is bFlowState.SENT_HTTP_RESPONSE)
            if awaiting_reply and time.monotonic() - item.time <= self.RESEND_TIMEOUT:
                continue  # the last send still has time to be answered

            if item.retries <= 0:
                if item.flow:
                    item.flow.kill()
                print(f"http flow {flow_id} timed out! flow killed.")
                del self.flows[flow_id]
                continue

            if (state is bFlowState.UNSENT_HTTP_REQUEST
                    or state is bFlowState.SENT_HTTP_REQUEST):
                pkg = bPacket(bPacketType.HTTP_REQUEST, flow_id,
                              bRequest(item.flow.get_state()).json())
                next_state = bFlowState.SENT_HTTP_REQUEST
            elif (state is bFlowState.UNSENT_HTTP_RESPONSE
                    or state is bFlowState.SENT_HTTP_RESPONSE):
                pkg = bPacket(bPacketType.HTTP_RESPONSE, flow_id,
                              bResponse(item.flow.get_state()).json())
                next_state = bFlowState.SENT_HTTP_RESPONSE
            else:
                continue

            self.send_packet(pkg)
            item.time = time.monotonic()
            item.state = next_state
            item.retries -= 1

    def handle_packet(self, pkg):
        """Apply a plugin reply to the flow it refers to.

        ACK resumes the flow and KILL kills it; in both cases the flow is no
        longer tracked, so it is neither resent nor timed out afterwards.
        NOTE(review): this assumes the mitmproxy side queues a fresh
        (id, FlowItem) for the response stage rather than mutating the
        acknowledged item -- confirm against the addon.
        Other packet types are reported and ignored.

        Raises:
            ValueError: if an ACK or KILL refers to a flow that is not tracked.
        """
        if pkg.ptype not in (bPacketType.ACK, bPacketType.KILL):
            print(f"got unexpected message {pkg.ptype}")
            return

        item = self.flows.pop(pkg.flowid, None)
        if item is None:
            raise ValueError(f"unknown flow {pkg.flowid}")
        if item.flow:
            if pkg.ptype == bPacketType.ACK:
                item.flow.resume()
            else:
                item.flow.kill()

    def handle_packets(self):
        """Receive and dispatch every packet available within the poll window.

        Keeps reading while data arrives; each poll waits at most
        POLL_TIMEOUT_MS. Malformed messages and replies for unknown flows are
        reported and skipped rather than ending the thread.
        """
        while self.socket.poll(self.POLL_TIMEOUT_MS) & zmq.POLLIN:
            msg = self.socket.recv()
            if not msg:
                continue
            try:
                # json.loads accepts the raw bytes frame; str(msg) would
                # produce "b'...'", which is never valid JSON
                result = json.loads(msg)
            except ValueError:  # JSONDecodeError or undecodable bytes
                print(f"malformed message received {msg}")
                continue
            if not isinstance(result, dict):
                print(f"malformed message received {msg}")
                continue

            # assumes replies use the same {"type", "id", "data"} envelope
            # that send_packet() emits -- confirm against the plugin
            pkg = bPacket(result.get("type"), result.get("id"), result.get("data"))
            try:
                self.handle_packet(pkg)
            except ValueError as err:
                print(f"could not handle packet for flow {pkg.flowid}: {err}")

    def run(self):
        """Connect to the plugin and service flows until stop() is called."""
        print("thread started")
        self.connect()
        try:
            while not self._stop_requested.is_set():
                self.get_new_flows()
                self.handle_packets()
                self.update_flows()
                # TODO: periodically ping the plugin and reconnect() when it
                # stops answering; self.timer is reserved for that.
        finally:
            self.disconnect()

    def disconnect(self):
        """Close the socket right away, discarding unsent messages (LINGER=0)."""
        if self.socket is None:
            return
        self.socket.setsockopt(zmq.LINGER, 0)
        self.socket.close()
        self.socket = None
        print("disconnected")

    def reconnect(self):
        """Drop the current socket, wait a second, and connect again."""
        print("reconnecting")
        self.disconnect()
        time.sleep(1)
        self.connect()

    def connect(self):
        """Create a PAIR socket and connect it to self.path.

        ZeroMQ connects asynchronously, so this does not confirm that the
        plugin is actually listening.
        """
        self.socket = self.context.socket(zmq.PAIR)
        self.socket.connect(self.path)
        print("connected")