import http.server
import socketserver
import urllib.parse
import dnslib
import base64

try:

    def resolve(zone, qname, qtype):
        """Return the ``value`` of the first zone entry matching *qname* and *qtype*.

        Args:
            zone: Iterable of zone entries as loaded from the zone JSON file.
                Each usable entry is a dict with ``"name"``, ``"type"`` and
                ``"value"`` keys.
            qname: Queried name, compared with ``==`` against each entry's
                ``"name"`` (the handler passes the question's qname label).
            qtype: Query-type mnemonic to match against ``"type"``, e.g. ``"SRV"``.

        Returns:
            The ``"value"`` of the first matching entry, or ``None`` if no entry
            matches. Entries that are not dicts or lack any of the three keys
            are skipped rather than raising, so one bad zone line cannot break
            every query.
        """
        required = {"name", "type", "value"}
        for record in zone:
            if not isinstance(record, dict) or not required <= record.keys():
                continue
            if record["name"] == qname and record["type"] == qtype:
                return record["value"]
        return None


    class DnsHttpRequestHandler(http.server.BaseHTTPRequestHandler):
        """Answer DNS-over-HTTPS GET queries (RFC 8484 style) from a JSON zone file.

        Only ``GET /dns-query?dns=<base64url-encoded DNS message>`` is served.
        The zone file is re-read on every request, so edits take effect without
        a restart.
        """

        # Location of the JSON zone file; read fresh for each request.
        zone_path = "/dns/zone.json"

        @staticmethod
        def _decode_dns_param(value):
            """Decode a ``dns`` query parameter (unpadded base64url) into bytes.

            Raises:
                ValueError: (``binascii.Error``) if *value* is not valid base64url.
            """
            # RFC 8484 strips the '=' padding; restore exactly what is missing.
            # (-len % 4) is 0 when the length is already a multiple of 4.
            return base64.urlsafe_b64decode(value + "=" * (-len(value) % 4))

        @staticmethod
        def _build_response(request, zone):
            """Build the authoritative reply to the parsed *request* using *zone*.

            Sets NXDOMAIN when ``resolve`` finds nothing, adds an SRV answer for
            SRV queries, and sets NOTIMP for matched types this server cannot
            encode yet.
            """
            question = request.q
            response = dnslib.DNSRecord(
                dnslib.DNSHeader(id=request.header.id, qr=1, aa=1, ra=1), q=question
            )

            value = resolve(zone, question.qname, dnslib.QTYPE[question.qtype])
            if not value:
                response.header.rcode = dnslib.RCODE.NXDOMAIN
            elif question.qtype == dnslib.QTYPE.SRV:
                print("SRV record")
                # SRV zone values must carry priority, weight, port and target.
                rdata = dnslib.SRV(value["priority"], value["weight"], value["port"], value["target"])
                response.add_answer(dnslib.RR(question.qname, question.qtype, rdata=rdata))
            else:
                # A record matched but there is no encoder for its type; say so
                # explicitly instead of crashing on an undefined rdata.
                response.header.rcode = dnslib.RCODE.NOTIMP
            return response

        def do_GET(self):
            """Handle one DoH GET request.

            Replies 404 for any path other than ``/dns-query``, 400 when the
            ``dns`` parameter is missing or is not a decodable DNS message, 500
            on any other failure, and otherwise 200 with an
            ``application/dns-message`` body.
            """
            import json

            try:
                url = urllib.parse.urlparse(self.path)
                if url.path != "/dns-query":
                    # send_error writes the status line, headers and body; a bare
                    # send_response never flushes without end_headers().
                    self.send_error(404)
                    return

                params = urllib.parse.parse_qs(url.query)
                if "dns" not in params:
                    self.send_error(400, "Missing dns parameter")
                    return

                try:
                    request = dnslib.DNSRecord.parse(self._decode_dns_param(params["dns"][0]))
                except (ValueError, dnslib.DNSError) as e:
                    # Client-supplied garbage is a bad request, not a server fault.
                    print(f"Error: {e}")
                    self.send_error(400, "Malformed DNS query")
                    return

                with open(self.zone_path, "r", encoding="utf-8") as f:
                    zone = json.load(f)

                response = self._build_response(request, zone)
                print(response)
                # Pack before sending any headers so a failure here still yields
                # a clean 500 instead of a truncated 200.
                payload = response.pack()
            except Exception as e:
                print(f"Error: {e}")
                self.send_error(500)
                return

            self.send_response(200)
            self.send_header("Content-type", "application/dns-message")
            self.send_header("Content-Length", str(len(payload)))
            self.end_headers()
            self.wfile.write(payload)


    PORT = 8053

    # Using the server as a context manager calls server_close() on the way
    # out, so the listening socket is released even if serve_forever() exits
    # via an exception or KeyboardInterrupt.
    with socketserver.TCPServer(("", PORT), DnsHttpRequestHandler) as my_server:
        # Start the server
        print(f"Starting server on port {PORT}")
        my_server.serve_forever()

except Exception as e:
    print(f"Error: {e}")