Add test decorator for pathhandling

This commit is contained in:
Lars Hahn 2023-09-20 20:58:07 +02:00
parent f96d33b9f9
commit d5fd1bacc5

View File

@ -3,6 +3,7 @@ import network
import ubinascii import ubinascii
import time import time
import usocket import usocket
import json
import sys import sys
@ -31,7 +32,6 @@ class Request:
) )
if len(http_request_path_split) > 1: if len(http_request_path_split) > 1:
print(http_request_path_split[1])
http_request.parameter = dict( http_request.parameter = dict(
param.split('=') param.split('=')
for param in http_request_path_split[1].split('&') for param in http_request_path_split[1].split('&')
@ -116,6 +116,7 @@ class Request:
def content(self, content): def content(self, content):
if self._content is None or len(self._content) == 0: if self._content is None or len(self._content) == 0:
self._content = content self._content = content
self._header['Content-Length'] = str(len(content))
@property @property
def request_bytes(self): def request_bytes(self):
request = bytearray(self._header_request_bytes) request = bytearray(self._header_request_bytes)
@ -127,6 +128,8 @@ class Request:
strpath = self._path strpath = self._path
if len(self._parameter) > 0: if len(self._parameter) > 0:
strpath += f"?{'&'.join('='.join(item) for item in self._parameter.items())}" strpath += f"?{'&'.join('='.join(item) for item in self._parameter.items())}"
if len(self._header) == 0:
return f"{self._method} {strpath} {self._protocol}\r\n\r\n".encode()
return f"{self._method} {strpath} {self._protocol}\r\n{headers}\r\n\r\n".encode() return f"{self._method} {strpath} {self._protocol}\r\n{headers}\r\n\r\n".encode()
class Response(Request): class Response(Request):
@ -139,7 +142,7 @@ class Response(Request):
def __init__(self, method, path, protocol, status, def __init__(self, method, path, protocol, status,
host="", header={}, content=None, parameter={}): host="", header={}, content=None, parameter={}):
super().__init__(method, path, protocol, host, header, content, parameter) super().__init__(method, path, protocol, host, header, content, parameter)
self._status = status self._status = str(status)
def __repr__(self): def __repr__(self):
return repr({ return repr({
@ -193,11 +196,27 @@ class MicroWebServer:
content = request_content_bytes.decode() content = request_content_bytes.decode()
request.content = content request.content = content
return request return request
@staticmethod
def default_locator():
return Response(None, None, "HTTP/1.0", 404)
def __init__(self, listen_addr="0.0.0.0", port=80): def __init__(self, listen_addr="0.0.0.0", port=80):
self._listen_addr = listen_addr self._listen_addr = listen_addr
self._port = port self._port = port
self._locations = {}
def location(self, path):
def decorator(functor):
self._locations[path] = functor
return functor
return decorator
@property
def listen_addr(self):
return self._listen_addr
@property
def port(self):
return self._port
def _create_socket(self): def _create_socket(self):
self._socket_addrinfo = usocket.getaddrinfo(self._listen_addr, self._port) self._socket_addrinfo = usocket.getaddrinfo(self._listen_addr, self._port)
@ -207,8 +226,10 @@ class MicroWebServer:
def _handle_request(self, request): def _handle_request(self, request):
logger.debug(f"Request: \n{request}") logger.debug(f"Request: \n{request}")
response = Response(request.method, request.path, request.protocol, "200") if request.path in self._locations.keys():
return response return self._locations[request.path]()
else:
return MicroWebServer.default_locator()
def serve(self): def serve(self):
self._create_socket() self._create_socket()
@ -257,6 +278,8 @@ def connect_wlan(ssid, passphrase, hostname="PicoW", country="DE", power_save=Fa
print(f"\rConnection to '{ssid}' established. Connected via address '{ip}' from '{wlan_mac_address}'.") print(f"\rConnection to '{ssid}' established. Connected via address '{ip}' from '{wlan_mac_address}'.")
return wlan return wlan
def main(): def main():
ssid = "SSID" ssid = "SSID"
pw = "PASSWORD" pw = "PASSWORD"
@ -264,8 +287,26 @@ def main():
get_logger() get_logger()
webserver = MicroWebServer() webserver = MicroWebServer()
@webserver.location(path="/")
def test_handler():
return Response("GET", "/", "HTTP/1.0", 204)
@webserver.location(path="/my/path")
def test_handler_content():
body = {
"Hello":"World"
}
resp = Response("GET", "/", "HTTP/1.0", 200)
# TODO: add some nicer handling with headers and content...
resp.header = {
"Content-Type": "application/json"
}
resp.content = json.dumps(body)
print(resp.header)
return resp
webserver.serve() webserver.serve()
if __name__=='__main__': if __name__=='__main__':
main() main()