From f91da7a47cc9aa92b86e8fe3f953692ac27cd951 Mon Sep 17 00:00:00 2001 From: End Date: Wed, 12 Aug 2020 15:11:29 +0200 Subject: [PATCH] statemachine --- mitmaddon/littlesnitch.py | 78 +++++++++++++++++++++++++++------------ 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/mitmaddon/littlesnitch.py b/mitmaddon/littlesnitch.py index 3eb833e..d5b4c59 100644 --- a/mitmaddon/littlesnitch.py +++ b/mitmaddon/littlesnitch.py @@ -19,7 +19,7 @@ class NetworkState(Enum): CONNECTED = auto() PING = auto() SENDING = auto() - + def convert_to_strings(obj): if isinstance(obj, dict): return {convert_to_strings(key): convert_to_strings(value) @@ -30,44 +30,74 @@ def convert_to_strings(obj): return str(obj)[2:-1] return obj -def get_msg(socket): - - msg = socket.recv() - try: - if msg: - return json.loads(msg) - except json.JSONDecodeError: - print(f"malformed message received {msg}") +class NetworkThread(threading.Thread): + def __init__(self, name, queue): + threading.Thread.__init__(self) + self.name = name + self.q = queue + self.context = zmq.Context() - return NO_MSG + def run(self): + self.connect() + msg = self.send_msg_and_expect() + def disconnect(self): + self.socket.setsockopt(zmq.LINGER,0) + self.socket.close() + + def reconnect(self): + self.disconnect() + self.connect() -def send_msg(msg, socket): - a = convert_to_strings(msg) - socket.send(str.encode(json.dumps(a))) + def connect(self): + self.socket = self.context.socket(zmq.PAIR) + self.socket.connect("tcp://127.0.0.1:12345") + def send_msg_and_expect(this, msg, expect, timeout=5, retries=3): + while retries: + a = convert_to_strings(msg) + self.socket.send(str.encode(json.dumps(a))) + if (client.poll(REQUEST_TIMEOUT) & zmq.POLLIN) != 0: + msg = self.socket.recv() + try: + if msg: + result = json.loads(msg) + if result["msg"] in expect: + return result + else: + print("got unexpected message {result}") + + except json.JSONDecodeError: + print(f"malformed message received {msg}") + retries -= 1 + self.reconnect() + return NO_MSG + + """ def networking(q): print("starting thread") - context = zmq.Context() - connected = False + state = NetworkState.DISCONNECTED a = None - while not connected: + while state == NetworkState.DISCONNECTED: socket = context.socket(zmq.PAIR) socket.connect("tcp://127.0.0.1:12345") msg = get_msg(socket) if msg["msg"] == "init": send_msg(ACK_MSG, socket) - connected = True + state = NetworkState.CONNECTED + + + timer = time.monotonic() - while connected: - if timer - time.monotonic() >= 5: + while state != NetworkState.DISCONNECTED: + if state == NetworkState.CONNECTED and timer - time.monotonic() >= 5: timer = time.monotonic() send_msg(PING_MSG,socket) msg = get_msg(socket) if msg["msg"] != "pong": - connected = False + state = NetworkState. msg = get_msg(socket) if msg['msg'] == "ping": @@ -88,11 +118,11 @@ def networking(q): self.q.task_done() else: connected = False - -class Counter: +""" +class LittleSnitchBridge: def __init__(self): self.q = Queue() - self.thread = threading.Thread(name="NetworkThread", target=networking, args=(self.q,)) + self.thread = NetworkThread("network", self.q) self.thread.start() def request(self, flow): @@ -104,5 +134,5 @@ class Counter: self.q.join() addons = [ - Counter() + LittleSnitchBridge() ]