#!/usr/bin/env python3

import os
import sys
import logging
import django

import signal
# Ignore SIGCHLD for the whole process.
# NOTE(review): presumably so that per-connection child processes spawned by
# the proxy are not left behind as zombies -- confirm against websockify.
signal.signal(signal.SIGCHLD, signal.SIG_IGN)

# Directory containing this script, and its parent directory.
DIR_PATH = os.path.dirname(os.path.abspath(__file__))
ROOT_PATH = os.path.abspath(os.path.join(DIR_PATH, "..", ""))
# Only sets the settings module when the environment does not already define it.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "webvirtcloud.settings")
# Fallback certificate path, used when WS_CERT from the settings is falsy.
CERT = DIR_PATH + "/cert.pem"

# Make the parent directory importable so the project packages resolve.
if ROOT_PATH not in sys.path:
    sys.path.append(ROOT_PATH)

# Initialise Django before the project imports that follow.
django.setup()

import re
import socket

# from six.moves import http_cookies as Cookie
from http import cookies as Cookie
from webvirtcloud.settings import WS_PORT, WS_HOST, WS_CERT
from vrtManager.connection import CONN_SSH, CONN_SOCKET
from console.sshtunnels import SSHTunnels
from optparse import OptionParser

parser = OptionParser()

# Command-line options as (flags, keyword arguments) pairs, registered in the
# order they should appear in the --help output.
_OPTION_SPECS = (
    (
        ("-v", "--verbose"),
        {
            "dest": "verbose",
            "action": "store_true",
            "help": "Verbose mode",
            "default": False,
        },
    ),
    (
        ("-d", "--debug"),
        {
            "dest": "debug",
            "action": "store_true",
            "help": "Debug mode",
            "default": False,
        },
    ),
    (
        ("-H", "--host"),
        {
            "dest": "host",
            "action": "store",
            "help": "Listen host",
            "default": WS_HOST,
        },
    ),
    (
        ("-p", "--port"),
        {
            "dest": "port",
            "action": "store",
            "help": "Listen port",
            # Fall back to 6080 when the settings give no port.
            "default": WS_PORT or 6080,
        },
    ),
    (
        ("-c", "--cert"),
        {
            "dest": "cert",
            "action": "store",
            "help": "Certificate file path",
            # Fall back to the certificate next to this script.
            "default": WS_CERT or CERT,
        },
    ),
)

for _flags, _settings in _OPTION_SPECS:
    parser.add_option(*_flags, **_settings)

options, args = parser.parse_args()

FORMAT = "%(asctime)s - %(name)s - %(levelname)s : %(message)s"

# Pick the log level from the command-line flags; --debug implies --verbose.
if options.debug:
    options.verbose = True
    _log_level = logging.DEBUG
elif options.verbose:
    _log_level = logging.INFO
else:
    _log_level = logging.WARNING

logging.basicConfig(level=_log_level, format=FORMAT)

# Locate a websockify implementation and record in USE_HANDLER whether it
# exposes ProxyRequestHandler. USE_HANDLER decides further down which proxy
# class gets defined and how the server is built.
try:
    from websockify import WebSocketProxy

    try:
        from websockify import ProxyRequestHandler
    except ImportError:
        # websockify is present but has no ProxyRequestHandler.
        USE_HANDLER = False
    else:
        USE_HANDLER = True
except ImportError:
    # No websockify package: fall back to the copy shipped as novnc.wsproxy.
    try:
        from novnc.wsproxy import WebSocketProxy
    except ImportError:
        # Neither implementation is importable; nothing can be served.
        print("Unable to import a websockify implementation,\n please install one")
        sys.exit(1)
    else:
        USE_HANDLER = False


def get_connection_infos(token):
    """Look up the console connection details for a console token.

    The token has the form ``<compute id>-<instance uuid>``; it is split on
    the first ``-`` only.

    Returns:
        A 7-tuple ``(connhost, connport, connuser, conntype, console_host,
        console_port, console_socket)``. ``connport`` is the text after the
        first ``:`` of the compute hostname when there is one, otherwise the
        integer ``22``.

    Raises:
        Exception: Any failure during the lookup is logged and re-raised
            unchanged.
    """
    from instances.models import Instance
    from vrtManager.instance import wvmInstance

    try:
        token_parts = token.split("-", 1)
        compute_id = int(token_parts[0])
        instance_uuid = token_parts[1]
        instance = Instance.objects.get(compute_id=compute_id, uuid=instance_uuid)
        compute = instance.compute
        conn = wvmInstance(
            compute.hostname,
            compute.login,
            compute.password,
            compute.type,
            instance.name,
        )

        hostname = compute.hostname
        if hostname.count(":"):
            # "host:port[:...]" -- keep the first two fields only.
            hostname_fields = hostname.split(":")
            connhost, connport = hostname_fields[0], hostname_fields[1]
        else:
            connhost, connport = hostname, 22

        connuser = compute.login
        conntype = compute.type
        console_host = conn.get_console_listener_addr()
        console_port = conn.get_console_port()
        console_socket = conn.get_console_socket()
    except Exception as e:
        logging.error(
            "Fail to retrieve console connection infos for token %s : %s", token, e
        )
        raise

    return (
        connhost,
        connport,
        connuser,
        conntype,
        console_host,
        console_port,
        console_socket,
    )


class CompatibilityMixIn(object):
    """Client-handling logic shared by both websockify integration styles.

    The host class is expected to provide ``headers``, ``msg``, ``vmsg``,
    ``verbose``, ``traffic_legend`` and ``do_proxy``.
    """

    def _token_from_cookie(self):
        """Return the ``token`` cookie value of the request, or ``None``."""
        # NoVNC uses its own convention that forwards the token
        # from the request to a cookie header, so that is where we look.
        hcookie = self.headers.get("cookie")
        token = None
        if not hcookie:
            return token

        cookie = Cookie.SimpleCookie()
        for hcookie_part in hcookie.split(";"):
            hcookie_part = hcookie_part.lstrip()
            try:
                cookie.load(hcookie_part)
            except Cookie.CookieError:
                # NOTE(stgleb): Do not print out cookie content
                # for security reasons.
                self.msg("Found malformed cookie")
            else:
                if "token" in cookie:
                    token = cookie["token"].value
        return token

    def _new_client(self, daemon, socket_factory):
        """Resolve the console target of this client and proxy traffic to it.

        Args:
            daemon: Truthy when the server runs daemonised; the traffic
                legend is only printed when it does not.
            socket_factory: Callable used for direct connections, invoked as
                ``socket_factory(host, port, connect=True)``.

        Raises:
            ValueError: If the request carries no ``token`` cookie.
            Exception: If a tunnel is required but the compute is not an SSH
                connection. Errors from the connection lookup, the tunnel
                setup and the proxy loop are re-raised as they are.
        """
        token = self._token_from_cookie()
        if token is None:
            # Without this check the lookup below would fail on an unbound
            # local instead of reporting the real problem.
            error_msg = "No console token found in the request cookies"
            self.msg(error_msg)
            raise ValueError(error_msg)

        (
            connhost,
            connport,
            connuser,
            conntype,
            console_host,
            console_port,
            console_socket,
        ) = get_connection_infos(token)

        logging.debug(
            "Connection infos :\n"
            "- connhost : '%s'\n"
            "- connport : '%s'\n"
            "- connuser : '%s'\n"
            "- conntype : '%s'\n"
            "- console_host : '%s'\n"
            "- console_port : '%s'\n"
            "- console_socket : '%s'\n",
            connhost,
            connport,
            connuser,
            conntype,
            console_host,
            console_port,
            console_socket,
        )

        # Only the SSH branch creates a tunnel; bind the name up front so the
        # cleanup code below can test it on every path.
        tunnel = None

        if console_socket and conntype == CONN_SOCKET:
            # Local socket on local host
            self.msg("Try to open local socket %s" % console_socket)
            tsock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            tsock.connect(console_socket)
        elif console_socket or re.match(r"^127\.", console_host):
            # Need tunnel to physical host
            if conntype != CONN_SSH:
                error_msg = (
                    "Need a tunnel to access console but can't mount "
                    "one because it's not a SSH host"
                )
                self.msg(error_msg)
                raise Exception(error_msg)
            try:
                # generate a string with all placeholders to avoid TypeErrors
                # in sprintf
                # https://github.com/retspen/webvirtmgr/pull/497
                error_msg = "Try to open tunnel on %s@%s:%s on console %s:%s "
                error_msg += "(or socket %s)"
                self.msg(
                    error_msg
                    % (
                        connuser,
                        connhost,
                        connport,
                        console_host,
                        console_port,
                        console_socket,
                    )
                )
                tunnel = SSHTunnels(
                    connhost,
                    connuser,
                    connport,
                    console_host,
                    console_port,
                    console_socket,
                )
                fd = tunnel.open_new()
                tunnel.unlock()
                tsock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
            except Exception as e:
                self.msg("Fail to open tunnel : %s" % e)
                raise
            self.msg("Tunnel opened")
        else:
            # Direct access
            self.msg("connecting to: %s:%s" % (connhost, console_port))
            tsock = socket_factory(connhost, console_port, connect=True)

        if self.verbose and not daemon:
            print(self.traffic_legend)

        # Start proxying
        try:
            self.msg("Start proxying")
            self.do_proxy(tsock)
        except Exception:
            if tunnel:
                self.vmsg(
                    "%s:%s (via %s@%s:%s) : Websocket client or Target closed"
                    % (console_host, console_port, connuser, connhost, connport)
                )
                if tsock:
                    try:
                        tsock.shutdown(socket.SHUT_RDWR)
                    except OSError:
                        # Shutdown failed, e.g. the peer is already gone;
                        # still release the socket and the tunnel below.
                        pass
                    tsock.close()
                tunnel.close_all()
            raise


if USE_HANDLER:

    class NovaProxyRequestHandler(ProxyRequestHandler, CompatibilityMixIn):
        """Request handler used when websockify provides ProxyRequestHandler."""

        def msg(self, *args, **kwargs):
            """Forward a message to the handler's ``log_message``."""
            self.log_message(*args, **kwargs)

        def vmsg(self, *args, **kwargs):
            """Emit a message through ``msg`` only in verbose mode."""
            if not self.verbose:
                return
            self.msg(*args, **kwargs)

        def new_websocket_client(self):
            """
            Called after a new WebSocket connection has been established.
            """
            # With this API the daemon flag and socket factory live on the server.
            server = self.server
            self._new_client(server.daemon, server.socket)

else:

    class NovaWebSocketProxy(WebSocketProxy, CompatibilityMixIn):
        """Proxy used when the server object itself handles new clients."""

        def new_client(self):
            """
            Called after a new WebSocket connection has been established.
            """
            # With this API the daemon flag and socket factory live on the proxy.
            self._new_client(self.daemon, self.socket)


if __name__ == "__main__":
    # Keyword arguments common to both ways of building the server.
    proxy_settings = {
        "listen_host": options.host,
        "listen_port": options.port,
        "source_is_ipv6": False,
        "verbose": options.verbose,
        "cert": options.cert,
        "key": None,
        "ssl_only": False,
        "daemon": False,
        "record": False,
        "web": False,
        "target_host": "ignore",
        "target_port": "ignore",
        "wrap_mode": "exit",
        "wrap_cmd": None,
    }

    if USE_HANDLER:
        # Create the WebSocketProxy with NovaProxyRequestHandler handler
        server = WebSocketProxy(
            RequestHandlerClass=NovaProxyRequestHandler,
            traffic=False,
            **proxy_settings,
        )
    else:
        # Create the NovaWebSockets proxy
        server = NovaWebSocketProxy(**proxy_settings)

    server.start_server()