import socket
import ssl
from time import sleep
from queue import Queue, Empty
from threading import Thread, Event, Lock, get_native_id
from select import select


def _wait_with_spinner(threads, message, symbols):
    """Block until every thread in ``threads`` has finished.

    While a thread is still alive, print ``message`` followed by a rotating
    spinner frame taken from ``symbols`` (ending in a carriage return so the
    console line is overwritten) roughly every 0.5 seconds.
    """
    ctr = 0
    for thread in threads:
        while thread.is_alive():
            print(f"{message}{symbols[ctr % len(symbols)]}\r", end="")
            ctr += 1
            # join() with a timeout returns as soon as the thread ends, so
            # shutdown is not padded by a fixed sleep.
            thread.join(0.5)


class MicroProxy:
    """Forwarding HTTP proxy that multiplexes all connections in one thread.

    A single listener thread waits with ``select`` on the listening socket
    and on every client/upstream socket pair ("channel"). The first data
    received on a client socket decides where its upstream socket connects:
    HTTPS targets (scheme "https" or port 443) are handled as a CONNECT
    tunnel; everything else is connected and the request forwarded as-is.
    Afterwards, bytes are relayed between the two sockets of a channel.
    """

    ### CLS FUNCTIONS #########################################################

    # Spinner frames for the "waiting..." console output. Raw string so the
    # backslash is a literal character rather than an invalid escape.
    wait_symbols = r"-\|/"

    @staticmethod
    def _get_socket(proxy_type="TCP"):
        """Return a new IPv4 socket: stream for "TCP", datagram otherwise."""
        return socket.socket(
            socket.AF_INET,
            socket.SOCK_STREAM if proxy_type == "TCP" else socket.SOCK_DGRAM,
        )

    @staticmethod
    def init_tunnle(request, sock_in, sock_out, host, port):
        """Set up a CONNECT tunnel from the client to ``host:port``.

        Only acts when the raw ``request`` bytes start with ``CONNECT``:
        connects ``sock_out`` to the target and sends the
        "200 established" answer to the client on ``sock_in``.

        Returns:
            True if the tunnel was established; False if the request is not
            a CONNECT request or connecting/answering failed (the error is
            printed).
        """
        if not request.startswith(b"CONNECT"):
            return False
        try:
            sock_out.connect((host, port))
            sock_in.sendall(b"HTTP/1.1 200 established\r\n\r\n")
        except Exception as e:
            print("Cannot initiate proxy tunnel:", e)
            return False
        return True

    @staticmethod
    def proxy_forward_filter(request):
        """Extract ``(protocol, host_domain, port)`` from a proxy request.

        Only the first line of ``request`` is inspected; its second token is
        the target, e.g. ``http://host[:port]/path`` or ``host:port`` (as
        sent with CONNECT). Deliberately avoids ``re``/``urllib`` so it stays
        usable on small MicroPython boards.

        Args:
            request: The decoded request text.

        Returns:
            A tuple ``(protocol, host_domain, port)``. ``protocol`` is the
            scheme written in the URL; without a scheme it is inferred from
            an explicit port ("https" for 443, "http" otherwise) and stays
            None when neither scheme nor port is given. Without an explicit
            port, ``port`` is 443 for "https" and 80 otherwise.

        Raises:
            ValueError: if the first line has no target token, or the port is
                not an integer in the range 1..65535.
        """
        request_line = request.split("\n")[0]
        tokens = request_line.split()
        if len(tokens) < 2:
            raise ValueError("Malformed request line: " + repr(request_line))
        url = tokens[1]

        protocol = None
        # Look for the scheme separator instead of a leading "http", so that
        # hosts such as "httpbin.org:443" are not mistaken for a scheme.
        if "://" in url:
            protocol, host_part = url.split("://", 1)
        else:
            host_part = url

        # Strip path/query first so "host:port/path" still parses.
        authority = host_part.split("/", 1)[0]
        host_domain, sep, port_text = authority.partition(":")

        if sep:
            port = int(port_text)
            if not 0 < port <= 65535:
                raise ValueError("Invalid port in request target: " + str(port))
            if protocol is None:
                protocol = "https" if port == 443 else "http"
        else:
            port = 443 if protocol == "https" else 80
        return (protocol, host_domain, port)

    ### OBJ FUNCTIONS #########################################################

    def __init__(self, buf_byte_size=4096, client_timeout=0.5, poll_interval=0.5):
        """Create an idle proxy.

        Args:
            buf_byte_size: Maximum bytes read per ``recv`` call.
            client_timeout: Timeout (seconds) applied to client and upstream
                sockets.
            poll_interval: Maximum seconds ``select`` blocks per loop, i.e.
                how quickly the listener thread notices ``stop()``.
        """
        self._buf_byte_size = buf_byte_size
        self._client_timeout = client_timeout
        self._poll_interval = poll_interval
        self._listener_event = Event()
        self._is_listening = False
        self._incoming = []             # sockets watched for readability
        self._outgoing = []             # sockets watched for writability
        self._channel_map = {}          # socket -> its peer socket
        self._channel_init = {}         # socket -> first request handled?
        self._channel_from_client = []  # client-side sockets of channels

    def _set_listener(self):
        """Resolve the bind address and create the listening socket."""
        # The listening socket is always AF_INET, so only ask for IPv4
        # results; an IPv6 first entry would make bind() fail.
        self._addr_listen = socket.getaddrinfo(
            self._addr, self._port, socket.AF_INET
        )[0][-1]
        if getattr(self, "_socket_listen", None) is not None:
            self.stop()
        self._socket_listen = MicroProxy._get_socket()
        self._socket_listen.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

    def listen(self, addr, port, proxy_type="TCP", backlog=0):
        """Bind to ``addr:port`` and serve from a background listener thread.

        Does nothing when already listening. A ``proxy_type`` other than
        "TCP" is stored as "UDP" and selects datagram sockets for upstream
        connections; the listening socket itself is always TCP.
        """
        if self._is_listening:
            return
        self._addr = addr
        self._port = port
        self._proxy_type = "TCP" if proxy_type == "TCP" else "UDP"
        self._set_listener()
        self._socket_listen.bind(self._addr_listen)
        self._socket_listen.listen(backlog)
        self._incoming.append(self._socket_listen)
        self._listen_thread = Thread(
            target=self._listener_thread, args=(self._listener_event,)
        )
        self._listen_thread.start()
        print(f"init done for serving on {self._addr_listen}")
        self._is_listening = True

    def stop(self):
        """Stop the listener thread and close every socket it managed."""
        if not self._is_listening:
            return
        self._listener_event.set()
        _wait_with_spinner(
            [self._listen_thread],
            "Waiting for listener thread to finish... ",
            MicroProxy.wait_symbols,
        )
        print("Listener thread finished closing safely.")
        self._socket_listen.close()
        # Close proxied sockets that were still open; they used to leak.
        for sock in list(self._channel_map):
            sock.close()
        self._is_listening = False
        self._incoming.clear()
        self._outgoing.clear()
        self._channel_map.clear()
        self._channel_init.clear()
        self._channel_from_client.clear()

    def join(self):
        """Block until the listener thread exits (no-op when not listening)."""
        if self._is_listening:
            self._listen_thread.join()

    def _listener_thread(self, event):
        """Select loop: accept clients and relay data until ``event`` is set."""
        while self._incoming and not event.is_set():
            # A finite timeout makes the loop re-check ``event`` while idle;
            # without it stop() could wait forever for traffic.
            readable, _, _ = select(
                self._incoming, self._outgoing, self._incoming,
                self._poll_interval,
            )
            for sock in readable:
                if sock is self._socket_listen:
                    try:
                        self._handle_connection_incoming()
                    except OSError as e:
                        print("Cannot accept connection:", e)
                    continue
                if sock not in self._channel_map:
                    # Closed earlier in this round together with its peer;
                    # recv() on it would fail.
                    continue
                if (
                    not self._channel_init[sock]
                    and sock not in self._channel_from_client
                ):
                    continue
                try:
                    data = sock.recv(self._buf_byte_size)
                    if data:
                        self._handle_connection_receive(sock, data)
                    else:
                        self._handle_connection_close(sock)
                except (OSError, ValueError) as e:
                    # One broken channel must not take the whole proxy down.
                    print("Closing proxy channel after error:", e)
                    if sock in self._channel_map:
                        self._handle_connection_close(sock)
        # Reset the event so a later listen() starts with it cleared.
        event.clear()

    def _handle_connection_incoming(self):
        """Accept a client and pair it with a fresh, unconnected upstream socket."""
        conn, addr = self._socket_listen.accept()
        conn.settimeout(self._client_timeout)
        reverse_conn = MicroProxy._get_socket(self._proxy_type)
        reverse_conn.settimeout(self._client_timeout)
        self._channel_from_client.append(conn)
        self._incoming.append(conn)
        self._channel_map[conn] = reverse_conn
        self._channel_map[reverse_conn] = conn
        self._channel_init[conn] = False
        self._channel_init[reverse_conn] = False

    def _handle_connection_receive(self, sock, data):
        """Forward ``data`` to the peer, connecting upstream on first use.

        Raises:
            ValueError / OSError: from request parsing or connecting; the
                listener loop closes the channel in that case.
        """
        reverse_sock = self._channel_map[sock]
        if self._channel_init[sock] or self._channel_init[reverse_sock]:
            reverse_sock.sendall(data)
            return

        protocol, host_domain, port = MicroProxy.proxy_forward_filter(data.decode())
        if protocol == "https" or port == 443:
            if not MicroProxy.init_tunnle(data, sock, reverse_sock, host_domain, port):
                # No usable upstream connection: drop the channel instead of
                # leaving an unconnected socket in the select set.
                self._handle_connection_close(sock)
                return
        else:
            # Not a tunnel request: connect and forward the request as-is.
            reverse_sock.connect((host_domain, port))
            reverse_sock.sendall(data)
        self._incoming.append(reverse_sock)
        self._channel_init[sock] = True
        self._channel_init[reverse_sock] = True

    def _handle_connection_close(self, sock):
        """Close ``sock`` and its peer and forget all bookkeeping for both."""
        reverse_sock = self._channel_map.get(sock)
        for s in (sock, reverse_sock):
            if s is None:
                continue
            if s in self._outgoing:
                self._outgoing.remove(s)
            if s in self._incoming:
                self._incoming.remove(s)
            if s in self._channel_from_client:
                self._channel_from_client.remove(s)
            s.close()
            self._channel_init.pop(s, None)
            self._channel_map.pop(s, None)


class ThreadProxy:
    """Forwarding HTTP proxy that serves each connection on a worker thread.

    A listener thread accepts connections and queues them; a resizable pool
    of worker threads takes them off the queue and relays traffic between
    the client and the upstream server in lock-step (read request until
    timeout, forward, read response until timeout, forward back).
    """

    # Spinner frames for the "waiting..." console output (raw string: the
    # backslash is a literal character).
    wait_symbols = r"-\|/"

    def __init__(
            self, n_threads_max=2, buf_byte_size=4096, client_timeout=0.5,
            poll_interval=0.5):
        """Create an idle proxy.

        Args:
            n_threads_max: Upper bound for the worker pool size.
            buf_byte_size: Maximum bytes read per ``recv`` call.
            client_timeout: Timeout (seconds) for client and upstream sockets.
            poll_interval: Timeout (seconds) of ``accept()`` in the listener,
                i.e. how quickly it notices ``stop()``.
        """
        self._n_threads = 0
        self._n_threads_max = n_threads_max
        self._buf_byte_size = buf_byte_size
        self._client_timeout = client_timeout
        self._poll_interval = poll_interval
        self._thread_events = []
        self._threads = []
        self._job_queue = Queue()
        self._threads_lock = Lock()
        self._max_lock = Lock()
        self._listener_event = Event()
        self._is_listening = False

    def set_max_thread_count(self, n_threads_max):
        """Set the pool limit, shrinking the running pool if it now exceeds it."""
        with self._max_lock:
            self._n_threads_max = n_threads_max
            if self._n_threads > self._n_threads_max:
                self.rescale(self._n_threads_max)

    def max_thread_count(self):
        return self._n_threads_max

    def thread_count(self):
        return self._n_threads

    def _get_socket(self):
        """Return a new IPv4 socket matching the configured proxy type."""
        return socket.socket(
            socket.AF_INET,
            socket.SOCK_STREAM if self._proxy_type == "TCP" else socket.SOCK_DGRAM,
        )

    def _set_listener(self):
        """Resolve the bind address and create the listening socket."""
        # Sockets here are always AF_INET, so only ask for IPv4 results.
        self._addr_listen = socket.getaddrinfo(
            self._addr, self._port, socket.AF_INET
        )[0][-1]
        if getattr(self, "_socket_listen", None) is not None:
            self.stop()
        self._socket_listen = self._get_socket()
        self._socket_listen.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

    def listen(self, addr, port, proxy_type="TCP", backlog=0):
        """Bind, start the worker pool and the listener thread (idempotent)."""
        if self._is_listening:
            return
        self._addr = addr
        self._port = port
        self._proxy_type = "TCP" if proxy_type == "TCP" else "UDP"
        self._set_listener()
        self._socket_listen.bind(self._addr_listen)
        self._socket_listen.listen(backlog)
        # accept() times out periodically so the listener can notice stop().
        self._socket_listen.settimeout(self._poll_interval)
        self.rescale(self._n_threads_max)
        self._listen_thread = Thread(
            target=self._listener_thread, args=(self._listener_event,)
        )
        self._listen_thread.start()
        print(f"init done for serving on {self._addr_listen}")
        self._is_listening = True

    def stop(self):
        """Stop listening, shut down all workers and close pending clients."""
        if not self._is_listening:
            return
        self._listener_event.set()
        _wait_with_spinner(
            [self._listen_thread],
            "Waiting for listener thread to finish... ",
            ThreadProxy.wait_symbols,
        )
        print("Listener thread finished closing safely.")
        self.rescale(0)
        self._drain_job_queue()
        self._socket_listen.close()
        del self._listen_thread
        self._is_listening = False

    def _drain_job_queue(self):
        """Close client connections that were accepted but never handled."""
        while True:
            try:
                _, conn = self._job_queue.get_nowait()
            except Empty:
                return
            conn.close()

    def rescale(self, n_threads):
        """Resize the worker pool to ``n_threads``, capped at the maximum.

        Values that are not non-negative ints leave the size unchanged.
        """
        if isinstance(n_threads, int) and n_threads >= 0:
            requested = n_threads
        else:
            requested = self._n_threads
        new_n_threads = min(requested, self._n_threads_max)
        old_n_threads = self._n_threads
        if old_n_threads < new_n_threads:
            self._spin_up(new_n_threads - old_n_threads)
        elif old_n_threads > new_n_threads:
            self._spin_down(old_n_threads - new_n_threads)
        if new_n_threads != old_n_threads:
            print(
                "Changed worker thread size "
                f"from {old_n_threads} to {new_n_threads}."
            )

    def join(self):
        """Block until the listener thread exits (no-op when not listening)."""
        if self._is_listening:
            self._listen_thread.join()

    def _spin_up(self, thread_cnt_new):
        """Start ``thread_cnt_new`` additional worker threads."""
        with self._threads_lock:
            new_events = [Event() for _ in range(thread_cnt_new)]
            new_threads = [
                Thread(target=self._worker_thread, args=(ev,))
                for ev in new_events
            ]
            self._thread_events.extend(new_events)
            self._threads.extend(new_threads)
            self._n_threads = len(self._threads)
            for thread in new_threads:
                thread.start()

    def _spin_down(self, thread_cnt_del):
        """Stop the ``thread_cnt_del`` most recently started workers and wait."""
        with self._threads_lock:
            remaining = self._n_threads - thread_cnt_del
            events_stop = self._thread_events[remaining:]
            threads_stop = self._threads[remaining:]
            self._thread_events = self._thread_events[:remaining]
            self._threads = self._threads[:remaining]
            # Keep the counter in sync with the pool; it used to stay at the
            # old value, so a later rescale()/listen() started no workers.
            self._n_threads = len(self._threads)
            for event in events_stop:
                event.set()
            _wait_with_spinner(
                threads_stop,
                "Wait for worker threads to finish... ",
                ThreadProxy.wait_symbols,
            )
            print("Worker threads finished spinning down safely.")

    def _listener_thread(self, event):
        """Accept clients and queue ``(addr, conn)`` jobs until ``event`` is set."""
        while not event.is_set():
            try:
                conn, addr = self._socket_listen.accept()
            except socket.timeout:
                continue  # periodic wake-up to re-check ``event``
            self._job_queue.put((addr, conn))
        # clear event to indicate it stopped at spindown task
        event.clear()

    @staticmethod
    def init_tunnle(request, sock_in, sock_out, host, port):
        """Set up a CONNECT tunnel; returns ``(sock_out, sock_in)`` unchanged.

        Only acts when ``request`` starts with ``CONNECT``. Failures are
        printed, not raised.
        """
        if request.startswith(b"CONNECT"):
            try:
                sock_out.connect((host, port))
                sock_in.sendall(b"HTTP/1.1 200 established\r\n\r\n")
            except Exception as e:
                print("Cannot initiate HTTPS connection:", e)
        return sock_out, sock_in

    @staticmethod
    def receive_data(sock, buf_byte):
        """Read from ``sock`` until the peer closes it or a ``recv`` times out."""
        chunks = []
        while True:
            try:
                chunk = sock.recv(buf_byte)
            except socket.timeout:
                break
            if not chunk:
                break
            chunks.append(chunk)
        return b"".join(chunks)

    def _sendrecv(self, sock_in, sock_out):
        """Relay one client connection to its upstream server in lock-step.

        Each round reads from the client until timeout, forwards it, then
        reads the server's answer until timeout and forwards it back. The
        session ends when the server returns nothing in a round.
        """
        init_data = ThreadProxy.receive_data(sock_in, self._buf_byte_size)
        if not init_data:
            return  # client sent nothing before closing / timing out
        protocol, host_domain, port = ThreadProxy.proxy_forward_filter(init_data.decode())
        if protocol == "https" or port == 443:
            sock_out, sock_in = ThreadProxy.init_tunnle(
                init_data, sock_in, sock_out, host_domain, port
            )
            # initial request is CONNECT and handled by init_tunnle
            is_init = False
        else:
            sock_out.connect((host_domain, port))
            is_init = True

        while True:
            if is_init:
                request = init_data
                is_init = False
            else:
                request = ThreadProxy.receive_data(sock_in, self._buf_byte_size)
            if request:
                sock_out.sendall(request)
            response = ThreadProxy.receive_data(sock_out, self._buf_byte_size)
            if not response:
                break
            sock_in.sendall(response)

    def _worker_thread(self, event):
        """Serve queued connections until ``event`` is set."""
        print(f"Worker Thread {get_native_id()}: Start working...")
        while not event.is_set():
            # A blocking get() with timeout avoids the empty()/get() race in
            # which a worker could block forever and stall _spin_down().
            try:
                addr, conn = self._job_queue.get(timeout=0.1)
            except Empty:
                continue
            print(f"Worker Thread {get_native_id()}: Handle request of {addr}")
            conn.settimeout(self._client_timeout)
            socket_client_thread = self._get_socket()
            socket_client_thread.settimeout(self._client_timeout)
            try:
                self._sendrecv(conn, socket_client_thread)
            except Exception as e:
                print("ERROR occured in Thread: ", e)
            finally:
                conn.close()
                socket_client_thread.close()
        # clear event to indicate it stopped at spindown task
        event.clear()

    @staticmethod
    def proxy_forward_filter(request):
        """Extract ``(protocol, host_domain, port)``; same rules as MicroProxy."""
        return MicroProxy.proxy_forward_filter(request)


def main():
    """Serve a select-based MicroProxy on 0.0.0.0:8080."""
    mitm = MicroProxy()
    mitm.listen(addr='0.0.0.0', port=8080)


def main_alt_thread():
    """Serve a thread-pool ThreadProxy (up to 20 workers) on 0.0.0.0:8080."""
    mitm = ThreadProxy(n_threads_max=20)
    mitm.listen(addr='0.0.0.0', port=8080)


if __name__ == "__main__":
    main()