#!/usr/bin/python3

import argparse
import atexit
import binascii
import http.server
import os
import platform
import re
import select
import shutil
import signal
import socket
import socketserver
import subprocess
import sys
import tempfile
import time

# Global verbosity switch; main() flips it when --verbose is given.
is_verbose = False


def print_verbose(s):
    """Print *s* to stdout, but only when verbose output is enabled."""
    if not is_verbose:
        return
    print(s)

def print_error(s):
    """Write *s* followed by a newline to standard error."""
    err_stream = sys.stderr
    print(s, file=err_stream)

def exit_error(s):
    """Report *s* on stderr with an "Error: " prefix and exit with status 1."""
    message = "Error: " + s
    print_error(message)
    raise SystemExit(1)

def bool_arg(val):
    """Map a truthy/falsy value to the "on"/"off" spelling used in qemu options."""
    if val:
        return "on"
    return "off"

def find_qemu(arch):
    """Locate the qemu system emulator binary for *arch*.

    Candidate names are ``qemu-system-<arch>`` and, when *arch* matches the
    host machine, ``qemu-kvm`` as a second choice.

    If the ``QEMU_BUILD_DIR`` environment variable is set, only that
    directory is searched.  Otherwise ``/usr/bin``, ``/usr/libexec`` and every
    entry of ``PATH`` are searched, in that order.

    Returns the path of the first candidate that is a regular file.  If no
    candidate is found, exit_error() is called, which terminates the process.
    """
    binary_names = [f"qemu-system-{arch}"]
    if arch == platform.machine():
        binary_names.append("qemu-kvm")

    if "QEMU_BUILD_DIR" in os.environ:
        # An explicit build dir overrides the system locations entirely.
        search_dirs = [os.environ["QEMU_BUILD_DIR"]]
    else:
        search_dirs = ["/usr/bin", "/usr/libexec"]
        if "PATH" in os.environ:
            search_dirs += os.environ["PATH"].split(os.pathsep)

    # Try every candidate name before giving up, so that the qemu-kvm
    # fallback is honoured for QEMU_BUILD_DIR as well.
    for binary_name in binary_names:
        for d in search_dirs:
            p = os.path.join(d, binary_name)
            if os.path.isfile(p):
                return p

    exit_error("Can't find " + " or ".join(binary_names))

def download_aboot_u_boot(arch, dest_path):
    """Download the autosig u-boot package for *arch* and extract u-boot.bin.

    The ``autosig-u-boot`` rpm is fetched with ``dnf download`` into a
    temporary directory, unpacked there with ``rpm2cpio | cpio`` and its
    ``boot/u-boot.bin`` is copied to *dest_path*.

    Raises subprocess.CalledProcessError if dnf, rpm2cpio or cpio fails, and
    RuntimeError if dnf succeeded but no rpm file was downloaded.
    """
    with tempfile.TemporaryDirectory(prefix="runvm-uboot") as tmpdir:
        autosigrepo = "https://mirror.stream.centos.org/SIGs/9-stream/automotive"
        subprocess.run(["dnf", "download", "-q",
                        "--downloaddir", tmpdir,
                        "--disablerepo", "*",
                        "--repofrompath", f"autosig,{autosigrepo}/{arch}/packages-main/",
                        "autosig-u-boot"], check=True)

        # Only consider rpm files rather than an arbitrary directory entry.
        rpm_files = sorted(f for f in os.listdir(tmpdir) if f.endswith(".rpm"))
        if not rpm_files:
            raise RuntimeError("dnf did not download any autosig-u-boot rpm")
        rpm_path = os.path.join(tmpdir, rpm_files[0])

        # The context manager closes the pipe and reaps rpm2cpio, even when
        # cpio fails, so no zombie process or open descriptor is left behind.
        with subprocess.Popen(["rpm2cpio", rpm_path], stdout=subprocess.PIPE) as rpm2cpio:
            subprocess.run(["cpio", "-id", "--quiet", "-D", tmpdir],
                           stdin=rpm2cpio.stdout, check=True)
        if rpm2cpio.returncode != 0:
            raise subprocess.CalledProcessError(rpm2cpio.returncode, rpm2cpio.args)

        shutil.copy(os.path.join(tmpdir, "boot/u-boot.bin"), dest_path)

def qemu_available_accels(qemu):
    """Return the accelerators reported by ``<qemu> -accel help``.

    *qemu* is the path of the qemu binary.  The result is the subset of
    ('kvm', 'xen', 'hvf', 'hax', 'tcg') whose name occurs anywhere in the
    command output, in that fixed order.

    Raises subprocess.CalledProcessError if the command exits non-zero.
    """
    # Pass an argument list so a binary path containing spaces stays intact.
    info = subprocess.check_output([qemu, "-accel", "help"]).decode('utf-8')
    known_accels = ('kvm', 'xen', 'hvf', 'hax', 'tcg')
    # Plain membership test: a name at offset 0 of the output counts too.
    return [accel for accel in known_accels if accel in info]

def random_id():
    """Return 8 random bytes from os.urandom as a 16-character lowercase hex string."""
    raw = os.urandom(8)
    return raw.hex()

def machine_id():
    """Return an identifier string for the host machine.

    The stripped content of /etc/machine-id is used when that file exists.
    Otherwise, on macOS the IOPlatformUUID reported by ``ioreg`` (dashes
    removed) is returned, and on any other platform the hex codes of the
    first 16 characters of the hostname repeated 16 times.
    """
    try:
        with open("/etc/machine-id", "r") as id_file:
            return id_file.read().strip()
    except FileNotFoundError:
        pass

    if sys.platform == "darwin":
        # macOS has no /etc/machine-id; ask the IO registry instead.
        import plistlib
        ioreg_cmd = "ioreg -rd1 -c IOPlatformExpertDevice -a"
        raw_plist = subprocess.check_output(ioreg_cmd.split(" "))
        platform_uuid = plistlib.loads(raw_plist)[0]["IOPlatformUUID"]
        return platform_uuid.replace("-", "")

    # Fallback for other systems: derive the id from the hostname.
    padded_name = (socket.gethostname() * 16)[:16]
    return ''.join(format(ord(ch), "x") for ch in padded_name)

def generate_mac_address():
    """Build a MAC address string from the machine id.

    The first octet is fixed to "FE"; the other five are the two-character
    slices of the machine id at offsets -12, -10, -8, -6 and -4.
    """
    ident = machine_id()

    octets = ["FE"]
    for start in range(-12, -2, 2):
        octets.append(ident[start:start + 2])
    return ":".join(octets)

def run_http_server(path):
    """Serve directory *path* over HTTP from a forked child process.

    The child changes into *path*, binds a TCPServer to 127.0.0.1 on a
    system-chosen port (port 0) and reports the actual port number to the
    parent through a socket pair before serving forever.

    The parent registers an atexit hook that sends SIGTERM to the child and
    returns the port number received from it.
    """
    writer, reader = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
    child_pid = os.fork()
    if child_pid == 0:
        # The child only writes; drop its copy of the reading end.
        reader.close()

        # Child
        os.chdir(path)

        class HTTPHandler(http.server.SimpleHTTPRequestHandler):
            def log_message(self, format, *args):
                pass # Silence logs

        # Port 0 lets the OS pick a free port; the real one is read back below.
        httpd = socketserver.TCPServer(("127.0.0.1", 0),  HTTPHandler)
        writer.send(str(httpd.server_address[1]).encode("utf8"))
        writer.close()

        httpd.serve_forever()
        sys.exit(0)

    # Parent
    writer.close()
    # Make sure the server child is terminated when this process exits.
    atexit.register(os.kill, child_pid,signal.SIGTERM)

    # NOTE(review): if the child fails before sending the port, recv() would
    # presumably return b"" and int() would raise ValueError — confirm.
    http_port = int(reader.recv(128).decode("utf8"))
    reader.close()

    return http_port

def run_virtiofs_server(socket, sharedir):
    """Start /usr/libexec/virtiofsd sharing *sharedir* and return its Popen handle.

    *socket* is the path passed as --socket-path.  Unless verbose mode is on,
    virtiofsd logging is switched off.
    """
    command = ["/usr/libexec/virtiofsd"]
    command.append("--socket-path=" + socket)
    command.extend(["-o", "source=" + sharedir])
    command.extend(["-o", "cache=always"])

    if not is_verbose:
        command.extend(["--log-level", "off"])

    print_verbose(f"Running: {' '.join(command)}")
    return subprocess.Popen(command)

def find_ovmf(args):
    """Return the first directory containing the required OVMF firmware pair.

    The directories searched are, in order: ``args.ovmf_dir`` (when set),
    ``~/.local/share/ovmf``, ``/usr/share/OVMF`` and ``/usr/share/edk2/ovmf``.
    A directory matches when it holds both ``OVMF_CODE<suffix>.fd`` and
    ``OVMF_VARS<suffix>.fd``, where the suffix is ``.secboot`` if
    ``args.secureboot`` is true and empty otherwise.

    Raises RuntimeError when no directory matches.
    """
    dirs = [
        "~/.local/share/ovmf",
        "/usr/share/OVMF",
        "/usr/share/edk2/ovmf",
    ]
    if args.ovmf_dir:
        dirs.insert(0, args.ovmf_dir)

    # The file names do not depend on the directory, so compute them once.
    suffix = ".secboot" if args.secureboot else ""
    required = (f"OVMF_CODE{suffix}.fd", f"OVMF_VARS{suffix}.fd")

    for d in dirs:
        path = os.path.expanduser(d)
        if all(os.path.exists(os.path.join(path, name)) for name in required):
            return path

    raise RuntimeError(f"Could not find OVMF (searched: {', '.join(dirs)})")

# Directories searched, in order, for qemu/edk2 firmware files.
qemu_dirs = [
    "/usr/local/share/qemu",
    "/opt/homebrew/share/qemu",
    "/usr/share/edk2/aarch64",
    "/usr/share/qemu"
]

# location can differ depending on how qemu is installed
def find_edk2():
    """Return the directory from qemu_dirs to load edk2 firmware from.

    A directory that actually contains both ``edk2-aarch64-code.fd`` and
    ``edk2-arm-vars.fd`` (the files main() loads from the returned path) is
    preferred.  If no directory has them, the first existing directory is
    returned, as before.

    Raises RuntimeError when none of the directories exists.
    """
    firmware_files = ("edk2-aarch64-code.fd", "edk2-arm-vars.fd")
    first_existing = None

    for path in qemu_dirs:
        if not os.path.exists(path):
            continue
        if all(os.path.exists(os.path.join(path, name)) for name in firmware_files):
            return path
        if first_existing is None:
            # Remember it as a fallback, but keep looking for a better match.
            first_existing = path

    if first_existing is not None:
        return first_existing

    raise RuntimeError("Could not find edk2 directory")

def find_edk2_code_fd():
    """Return the path of the first edk2 code firmware file that exists.

    Each directory in qemu_dirs is probed, in order, for ``QEMU_EFI.fd`` and
    then ``edk2-aarch64-code.fd``.  Raises RuntimeError if nothing is found.
    """
    candidate_names = ("QEMU_EFI.fd", "edk2-aarch64-code.fd")

    candidates = (
        os.path.join(directory, name)
        for directory in qemu_dirs
        for name in candidate_names
    )
    found = next((c for c in candidates if os.path.exists(c)), None)
    if found is None:
        raise RuntimeError("Could not find edk2 code fd file")
    return found

def qemu_run_command(qmp_socket_path, command):
    """Send one command to qemu over its QMP unix socket.

    Connects to *qmp_socket_path*, reads the initial message from the server
    (presumably the QMP greeting), sends ``qmp_capabilities`` followed by
    *command* (a JSON string), and reads up to 1024 bytes after each.  The
    replies are discarded and nothing is returned.
    """
    # The context manager closes the socket even if a send/recv raises.
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock2:
        sock2.connect(qmp_socket_path)
        sock2.recv(1024)
        # sendall() keeps writing until the whole message is out.
        sock2.sendall('{"execute":"qmp_capabilities"}\n'.encode("utf8"))
        sock2.recv(1024)
        sock2.sendall(f'{command}\n'.encode("utf8"))
        sock2.recv(1024)

def virtio_serial_connect(virtio_socket_path):
    """Connect to the unix socket at *virtio_socket_path*, retrying until it works.

    A connection attempt is made every 0.1 seconds.  Attempts that fail
    because the socket file does not exist yet, or exists but is not
    accepting connections yet, are retried indefinitely.  Any other error
    closes the socket and propagates.

    Returns the connected socket.
    """
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        while True:
            time.sleep(0.1)
            try:
                sock.connect(virtio_socket_path)
                return sock
            except (FileNotFoundError, ConnectionRefusedError):
                # Server side not ready yet: path missing, or bound but not
                # listening.  Try again on the next tick.
                pass
    except BaseException:
        # Do not leak the descriptor when giving up.
        sock.close()
        raise

def available_tcp_port(port_range_from = 1024):
    """Return the first TCP port at or above *port_range_from* that can be bound.

    At most 32 consecutive ports are probed by binding a socket with
    SO_REUSEADDR on all interfaces.  If every probe fails with OSError,
    ``port_range_from + 32`` is returned without having been tested.
    """
    retry_limit = 32
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

    # Value handed back when none of the probed ports could be bound.
    chosen = port_range_from + retry_limit
    for candidate in range(port_range_from, port_range_from + retry_limit):
        try:
            probe.bind(('', candidate))
        except OSError:
            continue
        chosen = candidate
        break

    probe.close()
    return chosen

class WatchdogCommand:
    """A single watchdog request: an operation code plus an optional argument."""

    # Operation codes stored in the ``op`` attribute.
    START = 1
    STOP = 2

    def __init__(self, op, arg=None):
        """Store the operation code *op* and its optional argument *arg*."""
        self.op, self.arg = op, arg

def parse_watchdog_commands(sock):
    """Read up to 16 bytes from *sock* and parse them into WatchdogCommand objects.

    Each received line starting with "START" yields a START command whose
    argument is the integer following the keyword (30 when it is missing or
    not a number); each line starting with "STOP" yields a STOP command.
    Other lines are reported through print_verbose and skipped.
    """
    text = sock.recv(16).decode("utf8")

    parsed = []
    for line in text.splitlines():
        if line.startswith("START"):
            try:
                seconds = int(line[5:])
            except ValueError:
                seconds = 30 # Default if not specified
            parsed.append(WatchdogCommand(WatchdogCommand.START, seconds))
            continue

        if line.startswith("STOP"):
            parsed.append(WatchdogCommand(WatchdogCommand.STOP))
            continue

        print_verbose(f"Unsupported watchdog command {line}")

    return parsed

def run_watchdog(watch_socket_path, qmp_socket_path):
    """Watchdog loop: reset the VM when an armed timeout expires.

    Connects to the socket at *watch_socket_path* and processes the commands
    read from it: START arms a deadline the given number of seconds ahead,
    STOP disarms it.  When an armed deadline passes, a ``system_reset`` is
    sent to qemu via *qmp_socket_path* and the deadline is re-armed 30
    seconds ahead.  The process exits with status 0 when the socket reports
    a hang-up.
    """
    sock = virtio_serial_connect(watch_socket_path)

    poller = select.poll()
    poller.register(sock, select.POLLIN)

    deadline = None
    rearm_delay = 30

    while True:
        # Block indefinitely while disarmed, otherwise until the deadline (ms).
        if deadline is None:
            wait_ms = None
        else:
            wait_ms = max(deadline - time.time(), 0) * 1000

        events = poller.poll(wait_ms)

        if events:
            _fd, flags = events[0]
            if flags & select.POLLHUP:
                sys.exit(0)
            for cmd in parse_watchdog_commands(sock):
                if cmd.op == WatchdogCommand.START:
                    print_verbose(f"Starting watchdog for {cmd.arg} sec")
                    deadline = time.time() + cmd.arg
                if cmd.op == WatchdogCommand.STOP:
                    print_verbose("Stopped watchdog")
                    deadline = None

        if deadline is not None and time.time() >= deadline:
            print_verbose("Triggering watchdog")
            qemu_run_command(qmp_socket_path, '{"execute": "system_reset"}')

            # Queue a new timeout in case the next boot fails, until disabled
            deadline = time.time() + rearm_delay


def main():
    """Parse the command line, assemble a qemu command line and run the VM.

    Depending on the options this also starts helper processes: an HTTP
    server for --publish-dir, a forked watchdog for --watchdog, swtpm for
    --tpm2 and virtiofsd for --sharedir.

    Returns the exit status of the qemu process.
    """
    parser = argparse.ArgumentParser(description="Boot virtual machine images")
    parser.add_argument("--verbose", default=False, action="store_true")
    parser.add_argument("--arch", default=platform.machine(), action="store",
                        help=f"Arch to run for (default {platform.machine()})")
    parser.add_argument("--publish-dir", action="store",
                        help=f"Publish the specified directory over http in the vm")
    parser.add_argument("--memory", default="2G",
                        help=f"Memory size (default 2G)")
    parser.add_argument("--nographics", default=False, action="store_true",
                        help=f"Run without graphics")
    parser.add_argument("--nosmp", default=False, action="store_true",
                        help=f"Use a single core")
    parser.add_argument("--aboot", default=False, action="store_true",
                        help=f"Boot with aboot")
    parser.add_argument("--watchdog", default=False, action="store_true",
                        help=f"Enable watchdog")
    parser.add_argument("--tpm2", default=False, action="store_true",
                        help=f"Enable TPM2")
    parser.add_argument("--nvme", default=False, action="store_true",
                        help=f"Use nvme instead of virtio")
    parser.add_argument("--snapshot", default=False, action="store_true",
                        help=f"Work on a snapshot  of the image")
    parser.add_argument("--ovmf-dir", action="store",
                        help="Specify directory for OVMF files (Open Virtual Machine Firmware)")
    parser.add_argument("--secureboot", dest="secureboot", action="store_true", default=False,
                        help="Enable SecureBoot")
    parser.add_argument("--ssh-port", type=int, default=2222,
                        help="SSH port forwarding to SSH_PORT (default 2222)")
    parser.add_argument("--port-forward", type=str, metavar="host:guest,...",
                        help="Add port forwarding rules by host:guest format with comma separation.\ne.g - \"8443:443,8143:143\" will forward the port 8443 of host machine to the post 443 of guest OS and the port 8143 of host to the port 143 of guest too")
    parser.add_argument("--cdrom", action="store",
                        help="Specify .iso to load")
    parser.add_argument("--noaccel", default=False, action="store_true",
                        help=f"Disable acceleration (kvm or hvf)")
    parser.add_argument("--sharedir", action="store",
                        help=f"Share directory using virtiofs")
    parser.add_argument("image", type=str, help="The image to boot")
    # Everything after the image is passed through to qemu unchanged.
    parser.add_argument('extra_args', nargs=argparse.REMAINDER, metavar="...", help="extra qemu arguments")

    args = parser.parse_args(sys.argv[1:])

    global is_verbose
    is_verbose = args.verbose

    # arm64 is an alias for aarch64 on macOS
    if args.arch == "arm64":
        args.arch = "aarch64"

    if args.aboot:
        if args.arch != "aarch64":
            exit_error("--aboot only supported with --arch=aarch64")
        # The u-boot binary is looked up (and cached) in the current directory.
        aboot_bios = f"qemu-u-boot-{args.arch}.bin"
        if not os.path.exists(aboot_bios):
            print(f"Missing file '{aboot_bios}'. Downloading")
            download_aboot_u_boot(args.arch, aboot_bios)

    qemu = find_qemu(args.arch)
    accel_list = qemu_available_accels(qemu)
    qemu_args = [qemu]

    num_cpus = os.cpu_count()

    # Per-architecture machine type, fallback CPU model and firmware.
    if args.arch == "x86_64":
        machine = "q35"
        default_cpu = "qemu64,+ssse3,+sse4.1,+sse4.2,+popcnt"

        ovmf = find_ovmf(args)
        if args.secureboot:
            qemu_args += [
                "-drive", f"file={ovmf}/OVMF_CODE.secboot.fd,if=pflash,format=raw,unit=0,readonly=on",
                "-drive", f"file={ovmf}/OVMF_VARS.secboot.fd,if=pflash,format=raw,unit=1,snapshot=on,readonly=off",
            ]
        else:
            qemu_args += [
                "-drive", f"file={ovmf}/OVMF_CODE.fd,if=pflash,format=raw,unit=0,readonly=on",
                "-drive", f"file={ovmf}/OVMF_VARS.fd,if=pflash,format=raw,unit=1,snapshot=on,readonly=off",
            ]
    elif args.arch == "aarch64":
        machine = "virt"
        default_cpu = "cortex-a57"
        num_cpus = min(os.cpu_count(), 8) # for up to 8 cores (limitation of qemu-system-aarch64)
        if sys.platform == "darwin":
            qemu_args += [
                "-device", "virtio-gpu-pci", # for display
                "-display", "default,show-cursor=on", # for display
                "-device", "qemu-xhci", # for keyboard
                "-device", "usb-kbd", # for keyboard
                "-device", "usb-tablet", # for mouse
            ]
        if args.aboot:
            # Note: this overrides any --memory value given by the user.
            args.memory = "4G" # this is hardcoded in our dtb files
            qemu_args += [
                "-bios", aboot_bios
            ]
        elif sys.platform == "darwin":
            edk2 = find_edk2()
            qemu_args += [
                "-drive", f"file={edk2}/edk2-aarch64-code.fd,if=pflash,format=raw,unit=0,readonly=on",
                "-drive", f"file={edk2}/edk2-arm-vars.fd,if=pflash,format=raw,unit=1,snapshot=on,readonly=off"
            ]
        else:
            edk2_file = find_edk2_code_fd()
            qemu_args += [
                "-bios", f"{edk2_file}",
                "-boot", "efi"
            ]
    else:
        exit_error(f"unsupported architecture {args.arch}")

    if not args.nosmp and num_cpus > 1:
        qemu_args += [
            "-smp", str(num_cpus)
        ]

    accel_enabled = True

    # There are some cases where acceleration may not work;
    # kvm-accelerated aboot is one (kernel crash)
    if args.noaccel:
        accel_enabled = False
    elif 'kvm' in accel_list and os.path.exists("/dev/kvm"):
        qemu_args += ['-enable-kvm']
    elif 'hvf' in accel_list:
        qemu_args += ['-accel', 'hvf']
    else:
        accel_enabled = False

    if not accel_enabled:
        print_verbose("Acceleration: off")

    # "-cpu host" is only used together with an accelerator.
    qemu_args += [
        "-m", str(args.memory),
        "-machine", machine,
        "-cpu", "host" if accel_enabled else default_cpu
    ]

    guestfwds=""

    if args.publish_dir:
        # The guest-side forward is implemented by running netcat on the host.
        if shutil.which("netcat") is None:
            print("Command `netcat` not found in path, ignoring publish-dir")
        else:
            httpd_port = run_http_server(args.publish_dir)
            guestfwds = f"guestfwd=tcp:10.0.2.100:80-cmd:netcat 127.0.0.1 {httpd_port},"
            print_verbose(f"publishing {args.publish_dir} on http://10.0.2.100/")

    # Maps host port -> guest port; ssh (guest port 22) is always forwarded.
    portfwd = {
        available_tcp_port(args.ssh_port): 22
    }

    if args.port_forward:
        for rule in args.port_forward.split(','):
            # NOTE(review): re.search is unanchored, so a rule with extra text
            # around the "host:guest" pair is still accepted — confirm intent.
            match = re.search('([0-9]+):([0-9]+)', rule)
            if match:
                host, guest = match.groups()
                portfwd[available_tcp_port(int(host))] = int(guest)
            else:
                exit_error(f'Invalid port-forward rule "{rule}"')

    for local, remote in portfwd.items():
        print_verbose(f"port: {local} → {remote}")

    fwds = [f"hostfwd=tcp::{h}-:{g}" for h, g in portfwd.items()]

    macstr = generate_mac_address()
    print_verbose(f"MAC: {macstr}")

    net_dev = "virtio-net-pci"
    if args.aboot:
        net_dev = "virtio-net-device"
    qemu_args += [
        "-device", f"{net_dev},netdev=n0,mac={macstr}",
        "-netdev", "user,id=n0,net=10.0.2.0/24," + guestfwds + ",".join(fwds),
    ]

    if args.nographics:
        qemu_args += ["-nographic"]

    runvm_id = random_id()

    # Holds the unix sockets (and snapshot TPM state) created below.
    tmpdir = tempfile.TemporaryDirectory(prefix=f"runvm-{runvm_id}")

    watchdog_pid = 0
    if args.watchdog:
        qmp_socket_path = os.path.join(tmpdir.name, "qmp-socket")
        watch_socket_path = os.path.join(tmpdir.name, "watch-socket")

        qemu_args += [
            "-qmp", f"unix:{qmp_socket_path},server=on,wait=off",
            "-device", "virtio-serial", "-chardev", f"socket,path={watch_socket_path},server=on,wait=off,id=watchdog",
            "-device", "virtserialport,chardev=watchdog,name=watchdog.0"
        ]

        # The watchdog runs in a forked child; the parent goes on to start qemu.
        watchdog_pid = os.fork()
        if watchdog_pid == 0:
            run_watchdog(watch_socket_path, qmp_socket_path)
            sys.exit(0)

    if args.tpm2:
        if shutil.which("swtpm") is None:
            exit_error("Command `swtpm` not found in path, this is needed for tpm2 support")

        tpm2_socket = os.path.join(tmpdir.name, "tpm-socket")

        # With --snapshot the TPM state is kept in the temporary directory,
        # otherwise in ".tpm2_state" in the current directory.
        if args.snapshot:
            tpm2_path = os.path.join(tmpdir.name, "tpm2_state")
        else:
            tpm2_path = ".tpm2_state"
        os.makedirs(tpm2_path, exist_ok=True)

        swtpm_args = ["swtpm", "socket", "--tpm2", "--tpmstate", f"dir={tpm2_path}", "--ctrl", f"type=unixio,path={tpm2_socket}" ]
        # NOTE(review): this handle is overwritten by the qemu result below and
        # swtpm is never waited on or terminated in this function.
        res = subprocess.Popen(swtpm_args)

        qemu_args += [
            "-chardev", f"socket,id=chrtpm,path={tpm2_socket}",
            "-tpmdev", "emulator,id=tpm0,chardev=chrtpm",
            "-device", "tpm-tis,tpmdev=tpm0"
        ]

    print_verbose(f"Image: {args.image}")

    # Choose disk device
    disk_if = "none"
    if args.nvme:
        qemu_args += [ "-device", "nvme,serial=deadbeef,drive=rootdisk"  ]
    elif args.aboot:
        # virtio-blk-pci doesn't work with aboot, so use virtio-blk-device
        qemu_args += [ "-device", "virtio-blk-device,scsi=off,drive=rootdisk"  ]
    else:
        # Default "if=virtio", typically a pci based virtio device
        disk_if = "virtio"

    # The image format is guessed from the file extension; default is qcow2.
    disk_format = "qcow2"
    if args.image.endswith(".raw") or args.image.endswith(".img"):
        disk_format = "raw"

    qemu_args += [
        "-drive", f"file={args.image},index=0,media=disk,format={disk_format},if={disk_if},id=rootdisk,snapshot={bool_arg(args.snapshot)}",
    ]

    if args.cdrom:
        qemu_args += [
            "-cdrom", args.cdrom,
            "-boot", "d"
        ]

    virtiod = None
    if args.sharedir:
        if not os.path.isdir(args.sharedir):
            exit_error(f"Shared dir {args.sharedir} is not a valid directory")

        vhostsocket = os.path.join(tmpdir.name, "vhost")

        virtiod = run_virtiofs_server(vhostsocket, args.sharedir)

        # Guest memory is backed by a shared file in /dev/shm for vhost-user-fs.
        qemu_args += [
            "-chardev", "socket,id=char0,path=" + vhostsocket,
            "-device", "vhost-user-fs-pci,queue-size=1024,chardev=char0,tag=host",
            "-object", "memory-backend-file,id=mem,size="+str(args.memory)+",mem-path=/dev/shm,share=on",
            "-numa", "node,memdev=mem"
        ]
        print(f"Sharing directory {args.sharedir}, mount using 'mount -t virtiofs host /mnt'")

    qemu_args += args.extra_args

    print_verbose(f"Running: {' '.join(qemu_args)}")

    # NOTE(review): on KeyboardInterrupt exit_error() exits immediately, so
    # the watchdog/virtiofsd/tmpdir cleanup below is skipped on that path.
    try:
        res = subprocess.run(qemu_args, check=False)
    except KeyboardInterrupt:
        exit_error("Aborted")

    # qemu has exited: stop the helpers and remove the temporary directory.
    if watchdog_pid:
        os.kill(watchdog_pid, signal.SIGTERM)

    if virtiod:
        virtiod.terminate()

    tmpdir.cleanup()

    return res.returncode


# Script entry point: the process exit status is whatever main() returns.
if __name__ == "__main__":
    sys.exit(main())
