# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.

# This file has been modified by NVIDIA CORPORATION.
#
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import logging
import os
import re
from shutil import which
from subprocess import Popen, PIPE, STDOUT

# Module-wide logger used by every helper in this file. Its level is set to
# DEBUG so that this logger itself does not filter out any message.
logger = logging.getLogger("gstwebrtc_app_resize")
logger.setLevel(logging.DEBUG)

def fit_res(w, h, max_w, max_h):
    """Scale a resolution down, keeping its aspect ratio, until it fits a limit.

    Args:
        w, h: Requested width and height in pixels.
        max_w, max_h: Largest width and height allowed.

    Returns:
        A ``(width, height)`` tuple. If the input is strictly smaller than the
        limit in both dimensions it is returned unchanged. Otherwise both
        dimensions are shrunk by the same factor and snapped to even integers
        that never exceed ``max_w`` / ``max_h``.
    """
    if w < max_w and h < max_h:
        # Input resolution fits
        return w, h

    # Reduce both dimensions by the same small factor until they fit, so the
    # aspect ratio is preserved.
    new_w = float(w)
    new_h = float(h)
    while new_w > max_w or new_h > max_h:
        new_w *= 0.9999
        new_h *= 0.9999

    def _snap_even(value, limit):
        # Snap final resolution to be divisible by 2. Round up when that
        # stays within the limit; otherwise round down so an odd maximum
        # (or a value already equal to it) is never exceeded.
        v = int(value)
        if v % 2:
            v = v + 1 if v + 1 <= limit else v - 1
        return v

    return _snap_even(new_w, max_w), _snap_even(new_h, max_h)

def get_new_res(res):
    """Inspect ``xrandr`` output and compute the resolution to apply for ``res``.

    Args:
        res: Requested resolution as a ``"WxH"`` string.

    Returns:
        A 6-tuple ``(curr_res, new_res, resolutions, max_res, screen_name,
        dpi_adjustment)``:

        * ``curr_res`` -- the ``current W x H`` size reported by xrandr as
          ``"WxH"`` (``res`` if xrandr printed none).
        * ``new_res`` -- ``res`` scaled down by fit_res() to fit ``max_res``.
        * ``resolutions`` -- sorted ``"WxH"`` mode names found on lines after
          the first connected output.
        * ``max_res`` -- the limit used: ``$MAX_RESOLUTION`` if set, else
          ``"2560x1600"`` for outputs whose name starts with ``DVI`` and
          ``"4096x2160"`` otherwise.
        * ``screen_name`` -- name of the last connected output seen.
        * ``dpi_adjustment`` -- ``new_width / requested_width`` (1.0 if the
          requested width is 0).

        If xrandr reports no connected output, ``new_res`` and ``max_res``
        are ``res`` unchanged, ``screen_name`` is ``"screen"`` and
        ``dpi_adjustment`` is ``1.0``; the tuple shape is always the same.

    Raises:
        ValueError: if ``res`` or ``$MAX_RESOLUTION`` is not ``"WxH"``.
    """
    screen_name = "screen"
    resolutions = []

    screen_pat = re.compile(r'(.*)? connected.*?')
    current_pat = re.compile(r'.*current (\d+ x \d+).*')
    res_pat = re.compile(r'^(\d+x\d+)\s.*$')

    found_screen = False
    curr_res = new_res = max_res = res
    with os.popen('xrandr') as pipe:
        for raw_line in pipe:
            line = raw_line.strip()
            screen_ma = screen_pat.match(line)
            if screen_ma:
                found_screen = True
                screen_name = screen_ma.group(1)
            current_ma = current_pat.match(line)
            if current_ma:
                curr_res = current_ma.group(1).replace(" ", "")
            if found_screen:
                res_ma = res_pat.match(line)
                if res_ma:
                    resolutions.append(res_ma.group(1))

    if not found_screen:
        logger.error("failed to find screen info in xrandr output")
        # Return the same 6-tuple shape as the success path so callers such
        # as resize_display() can always unpack it.
        return curr_res, new_res, resolutions, max_res, screen_name, 1.0

    w, h = [int(i) for i in res.split('x')]

    if screen_name.startswith("DVI"):
        # Set max resolution for hardware accelerator.
        max_res = os.getenv("MAX_RESOLUTION", "2560x1600")
    else:
        max_res = os.getenv("MAX_RESOLUTION", "4096x2160")

    max_w, max_h = [int(i) for i in max_res.split('x')]
    new_w, new_h = fit_res(w, h, max_w, max_h)
    new_res = "%dx%d" % (new_w, new_h)

    # Guard a zero-width request against ZeroDivisionError.
    dpi_adjustment = new_w / w if w else 1.0

    resolutions.sort()
    return curr_res, new_res, resolutions, max_res, screen_name, dpi_adjustment

g_dpiAdjustment = 1.0  # scale factor read by get_adjusted_dpi() and set_cursor_size()

def set_dpi_adjustment(dpi_adjustment):
    """Store ``dpi_adjustment`` rounded to the nearest multiple of 0.25.

    The result is kept in the module-level ``g_dpiAdjustment`` and logged.
    """
    global g_dpiAdjustment

    quarter_steps = round(dpi_adjustment * 4)
    g_dpiAdjustment = quarter_steps / 4
    logger.info(f"DPI adjustment is set to {g_dpiAdjustment}")

def get_adjusted_dpi(dpi):
    """Scale ``dpi`` by ``g_dpiAdjustment`` and clamp the result to [96, 192].

    A falsy ``dpi`` (``None``, ``0``) yields the default of 96. The scaled
    value is truncated to an int before clamping.
    """
    if not dpi:
        return 96
    scaled = int(dpi * g_dpiAdjustment)
    return max(96, min(192, scaled))
    
def set_adjusted_dpi():
    """Re-apply the current DPI scaled by the DPI adjustment, plus a cursor size.

    Reads the current DPI with get_dpi() and returns early if that yields
    nothing. The cursor size is one sixth of the new DPI (truncated). If
    _set_dpi() or _set_cursor_size() report failure, an error is logged and
    execution continues.
    """
    current = get_dpi()
    if not current:
        return

    target_dpi = get_adjusted_dpi(current)
    logger.info(f"Setting adjusted DPI to: {target_dpi}")
    if not _set_dpi(target_dpi):
        logger.error(f"failed to set DPI to {target_dpi}")

    cursor = int(target_dpi / 6)
    logger.info(f"Setting cursor size to: {cursor}")
    if not _set_cursor_size(cursor):
        logger.error(f"failed to set cursor size to {cursor}")

def _run_command(cmd):
    """Run ``cmd`` (an argv list, no shell) and return ``(returncode, stdout, stderr)``."""
    p = Popen(cmd, stdout=PIPE, stderr=PIPE)
    stdout, stderr = p.communicate()
    return p.returncode, stdout, stderr

def resize_display(res):
    """Switch the X screen to ``res``, scaled down to the allowed maximum.

    If the target mode is not already listed by xrandr, a modeline is
    generated with ``cvt`` and registered via ``xrandr --newmode`` and
    ``xrandr --addmode``. The mode is then applied with ``xrandr --output``,
    after which the DPI adjustment is updated and DPI / cursor size are
    re-applied.

    Args:
        res: Requested resolution as a ``"WxH"`` string.

    Returns:
        True if the mode was applied. False if the resolution is already
        current, no modeline could be generated, or any xrandr step failed
        (errors are logged).
    """
    curr_res, new_res, resolutions, _max_res, screen_name, dpi_adjustment = get_new_res(res)
    if curr_res == new_res:
        logger.info("target resolution is the same: %s, skipping resize" % res)
        return False

    res = mode = new_res

    logger.info("resizing display to %s" % res)
    if res not in resolutions:
        logger.info("adding mode %s to xrandr screen '%s'" % (res, screen_name))

        # Generate modeline, this works for Xvfb, not sure about xserver with nvidia driver
        # https://securitronlinux.com/debian-testing/how-to-calculate-vesa-gtf-modelines-with-the-command-line-on-linux/
        mode, modeline = generate_xrandr_gtf_modeline(res)
        if not modeline:
            # Without timings `xrandr --newmode` cannot succeed; fail early
            # with a clearer message.
            logger.error("failed to generate modeline for xrandr mode: %s" % mode)
            return False

        # Create new mode from modeline
        logger.info("creating new xrandr mode: %s %s" % (mode, modeline))
        rc, stdout, stderr = _run_command(['xrandr', '--newmode', mode, *modeline.split()])
        if rc != 0:
            logger.error("failed to create new xrandr mode: '%s %s': %s%s" % (mode, modeline, str(stdout), str(stderr)))
            return False

        # Add the mode to the screen.
        logger.info("adding xrandr mode '%s' to screen '%s'" % (mode, screen_name))
        rc, stdout, stderr = _run_command(['xrandr', '--addmode', screen_name, mode])
        if rc != 0:
            logger.error("failed to add mode '%s' using xrandr: %s%s" % (mode, str(stdout), str(stderr)))
            return False

    # Apply the resolution change
    logger.info("applying xrandr screen '%s' mode: %s" % (screen_name, mode))
    rc, stdout, stderr = _run_command(['xrandr', '--output', screen_name, '--mode', mode])
    if rc != 0:
        logger.error("failed to apply xrandr mode '%s': %s%s" % (mode, str(stdout), str(stderr)))
        return False

    set_dpi_adjustment(dpi_adjustment)
    set_adjusted_dpi()

    return True

def generate_xrandr_gtf_modeline(res):
    """Generate an xrandr mode name and timing line for ``res`` with ``cvt -r``.

    Args:
        res: Resolution as ``"WxH"``, ``"W H"`` or ``"W H REFRESH"``. The
            refresh rate defaults to 60 when omitted.

    Returns:
        ``(mode, modeline)`` where ``mode`` is ``"WxH"`` and ``modeline`` is
        the text following the quoted name on the last ``Modeline`` line cvt
        prints. ``modeline`` is ``""`` if cvt cannot be run or prints no
        ``Modeline`` line.

    Raises:
        Exception: if ``res`` is not in a supported format or any field is
            not a plain number.
    """
    modeline = ""
    modeline_pat = re.compile(r'^.*Modeline\s+"(.*?)"\s+(.*)')
    int_pat = re.compile(r'[0-9]+')
    refresh_pat = re.compile(r'[0-9]+(?:\.[0-9]+)?')

    xtoks = res.split("x")
    if len(xtoks) == 2:
        # have WxH format
        width, height, refresh = xtoks[0].strip(), xtoks[1].strip(), "60"
    else:
        toks = res.split()
        if len(toks) == 2:
            # have W H format
            width, height = toks
            refresh = "60"
        elif len(toks) == 3:
            # have W H refresh format
            width, height, refresh = toks
        else:
            raise Exception("unsupported input resolution format: {}".format(res))

    # The fields are handed to an external program, so accept plain numbers only.
    if not (int_pat.fullmatch(width) and int_pat.fullmatch(height)
            and refresh_pat.fullmatch(refresh)):
        raise Exception("unsupported input resolution format: {}".format(res))

    mode = "{}x{}".format(width, height)

    # Run cvt with an argv list (no shell) so the input cannot inject commands.
    try:
        p = Popen(['cvt', '-r', width, height, refresh],
                  stdout=PIPE, stderr=PIPE, text=True)
        stdout, _ = p.communicate()
    except OSError:
        # cvt missing or not executable: report "no modeline" as before.
        return mode, modeline

    for line in stdout.splitlines():
        modeline_ma = modeline_pat.match(line.strip())
        if modeline_ma:
            _, modeline = modeline_ma.groups()
    return mode, modeline

def _set_dpi(dpi):
    """Merge ``Xft.dpi: <dpi>`` into the X resource database with ``xrdb``.

    Args:
        dpi: Integer DPI value to store.

    Returns:
        True on success. False if xrdb cannot be started or exits non-zero;
        the error is logged in both cases.
    """
    cmd = ['xrdb', '-nocpp', '-merge']
    input_data = f"Xft.dpi: {dpi}\n"
    try:
        p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, text=True)
        stdout, stderr = p.communicate(input=input_data)
    except OSError as e:
        # e.g. xrdb not installed: honour the True/False contract instead of raising.
        logger.error("failed to run xrdb to set DPI: %s" % e)
        return False
    if p.returncode != 0:
        logger.error("failed to set XFCE DPI to: '%d': %s%s" % (dpi, str(stdout), str(stderr)))
        return False

    return True

def set_dpi(dpi):
    """Apply ``dpi`` after scaling and clamping it with get_adjusted_dpi().

    Returns whatever _set_dpi() returns.
    """
    return _set_dpi(get_adjusted_dpi(dpi))
    
def get_dpi():
    """Return the current ``Xft.dpi`` value reported by ``xrdb -query``.

    Returns:
        The DPI as an int, or 96 if xrdb lists no usable ``Xft.dpi`` entry.
        Returns None if xrdb cannot be started or exits non-zero (logged).
    """
    cmd = ['xrdb', '-query']
    try:
        p = Popen(cmd, stdout=PIPE, stderr=PIPE, text=True)
        stdout, stderr = p.communicate()
    except OSError as e:
        # e.g. xrdb not installed: report failure the same way as a non-zero exit.
        logger.error("failed to run xrdb to get DPI: %s" % e)
        return None
    if p.returncode != 0:
        logger.error("failed to get XFCE DPI: %s%s" % (str(stdout), str(stderr)))
        return None

    dpi = None
    for line in stdout.splitlines():
        if line.startswith("Xft.dpi:"):
            raw = line.split(':', 1)[1].strip()
            try:
                # Accept fractional values such as "96.0" (truncated to int).
                dpi = int(float(raw))
            except (ValueError, OverflowError):
                logger.warning("ignoring unparsable Xft.dpi value: %r" % raw)
            break

    logger.info(f"Previous dpi: {dpi}")
    return dpi or 96

def set_cursor_size(size):
    """Scale ``size`` by ``g_dpiAdjustment`` (at least 16) and apply it.

    Returns whatever _set_cursor_size() returns.
    """
    adjusted = max(16, int(size * g_dpiAdjustment))
    logger.info(f"Setting cursor size to the adjusted: {adjusted}")
    return _set_cursor_size(adjusted)

def _set_cursor_size(size):
    """Set ``/Gtk/CursorThemeSize`` in the ``xsettings`` channel via ``xfconf-query``.

    Args:
        size: Cursor size to store (written as an int property, created if missing).

    Returns:
        True on success. False (with a logged message) if ``xfconf-query`` is
        not on PATH or exits non-zero.
    """
    if not which("xfconf-query"):
        logger.warning("failed to find supported window manager to set cursor size.")
        return False

    cmd = ["xfconf-query", "-c", "xsettings", "-p", "/Gtk/CursorThemeSize",
           "-s", str(size), "--create", "-t", "int"]
    # text=True so logged output is readable, consistent with the xrdb helpers.
    p = Popen(cmd, stdout=PIPE, stderr=PIPE, text=True)
    stdout, stderr = p.communicate()
    if p.returncode != 0:
        logger.error("failed to set XFCE cursor size to: '%d': %s%s" % (size, str(stdout), str(stderr)))
        return False

    return True

def main():
    """Command-line entry point: resize the display to ``sys.argv[1]`` (``WxH``).

    Prints the boolean returned by resize_display(). With no argument, prints
    a usage line and exits with status 1.
    """
    import sys

    logging.basicConfig(level=logging.INFO)

    argv = sys.argv
    if len(argv) >= 2:
        print(resize_display(argv[1]))
        return
    print("USAGE: %s WxH" % argv[0])
    sys.exit(1)

if __name__ == "__main__":
    main()
