#!/usr/bin/python3

# This script evaluates whether firmware redirection is likely. It uses three cues:
# - Does the system offer up SPCR? This would indicate that the firmware is doing serial output.
#   Otherwise, there's no indication that the firmware cares about serial console.
# - Is the system EFI?  BIOS implementations may not intercept text draw calls after POST exit,
#   thus even when BIOS tells us serial port is in use, it may not be doing anything when
#   grub would be running
# - Is the serial port connected? If the firmware indicates a serial port but that port
#   is not reporting DCD, the scenario is not certain enough to rely on firmware redirection.

import fcntl
import os
import os.path
import struct
import subprocess
import termios


# I/O port address (as read from the SPCR table) -> tty device path.
addrtoname = dict(zip(
    (0x3f8, 0x2f8, 0x3e8, 0x2e8),
    ('/dev/ttyS0', '/dev/ttyS1', '/dev/ttyS2', '/dev/ttyS3'),
))

# SPCR baud rate code -> numeric baud rate. Code 0 maps to None, in which
# case get_serial_config() leaves the tty speed untouched.
speedmap = {0: None, 3: 9600, 4: 19200, 6: 57600, 7: 115200}

# Numeric baud rate -> matching termios speed constant.
termiobaud = {
    baud: getattr(termios, 'B{0}'.format(baud))
    for baud in (9600, 19200, 57600, 115200)
}


def deserialize_grub_rh():
    """Drop the serial terminal from grub on Red Hat family systems.

    Does nothing when the running kernel command line contains
    ``console=ttyS``. Otherwise, every line of /etc/default/grub that
    starts with ``GRUB_TERMINAL`` has the text ``'serial '`` removed, the
    file is rewritten in place, and /boot/grub2/grub.cfg is regenerated
    with grub2-mkconfig.

    Raises:
        OSError: if /proc/cmdline or /etc/default/grub cannot be read or
            written.
        subprocess.CalledProcessError: if grub2-mkconfig exits non-zero.
    """
    with open('/proc/cmdline') as cmdline:
        if 'console=ttyS' in cmdline.read():
            # User manually indicated serial config;
            # they own the grub behavior too for now
            return None
    with open('/etc/default/grub') as grubin:
        grublines = grubin.read().split('\n')
    # A file ending in a newline splits into a trailing empty string; drop
    # it so rewriting does not append one more blank line on every run.
    if grublines and grublines[-1] == '':
        grublines.pop()
    with open('/etc/default/grub', 'w') as grubout:
        for grubline in grublines:
            if grubline.startswith('GRUB_TERMINAL'):
                grubline = grubline.replace('serial ', '')
            grubout.write(grubline + '\n')
    subprocess.check_call(['grub2-mkconfig', '-o', '/boot/grub2/grub.cfg'])

def fixup_ubuntu_grub_serial():
    """Make the Ubuntu grub menu visible and text-mode, then run update-grub.

    Rewrites /etc/default/grub in place:

    - a line starting with ``GRUB_TIMEOUT_STYLE=hidden`` becomes
      ``GRUB_TIMEOUT_STYLE=menu``
    - a line starting with ``GRUB_TIMEOUT=0`` becomes ``GRUB_TIMEOUT=2``
    - a line starting with ``#GRUB_TERMINAL=console`` is uncommented

    Raises:
        OSError: if /etc/default/grub cannot be read or written.
        subprocess.CalledProcessError: if update-grub exits non-zero.
    """
    # Ubuntu aggressively tries to graphics up
    # grub. We will counter that for serial
    # They also aggressively hide UI and
    # block ability to interject. We will
    # compromise and lean on nodeboot <node> setup
    # as a means to give someone reasonable shot at
    # the short timeout
    with open('/etc/default/grub') as grubin:
        grublines = grubin.read().split('\n')
    # A file ending in a newline splits into a trailing empty string; drop
    # it so rewriting does not append one more blank line on every run.
    if grublines and grublines[-1] == '':
        grublines.pop()
    with open('/etc/default/grub', 'w') as grubout:
        for grubline in grublines:
            if grubline.startswith('GRUB_TIMEOUT_STYLE=hidden'):
                grubline = 'GRUB_TIMEOUT_STYLE=menu'
            elif grubline.startswith('GRUB_TIMEOUT=0'):
                grubline = 'GRUB_TIMEOUT=2'
            elif grubline.startswith('#GRUB_TERMINAL=console'):
                # Strip only the leading comment marker so that any
                # trailing "# ..." remark on the line stays a comment.
                grubline = grubline.replace('#', '', 1)
            grubout.write(grubline + '\n')
    subprocess.check_call(['update-grub'])

def get_serial_config():
    """Describe the firmware serial console advertised through ACPI SPCR.

    Returns:
        None when the system is not EFI, has no SPCR table, the table is
        too short or fails the revision/interface/address-space checks,
        or its port address or baud code is not in ``addrtoname`` /
        ``speedmap``. Otherwise a dict with:

        - ``'tty'``: device path of the serial port
        - ``'speed'``: numeric baud rate, or None when the table's baud
          code is 0
        - ``'connected'``: True when the port reports carrier detect
          (``TIOCM_CAR``)

    Raises:
        OSError: if the SPCR table or the tty device cannot be opened, or
            the modem-status ioctl fails.
        termios.error: if the tty attributes cannot be read or set.
    """
    if not os.path.exists('/sys/firmware/efi'):
        return None
    if not os.path.exists('/sys/firmware/acpi/tables/SPCR'):
        return None
    with open('/sys/firmware/acpi/tables/SPCR', 'rb') as spcrfile:
        spcr = bytearray(spcrfile.read())
    # Offset 58 is the highest byte read below; a truncated table cannot
    # be interpreted.
    if len(spcr) < 59:
        return None
    # NOTE(review): per the SPCR table layout these are presumably the
    # table revision (8), interface type (36) and address space id (40);
    # anything other than 2 / 0 / 1 is treated as unsupported.
    if spcr[8] != 2 or spcr[36] != 0 or spcr[40] != 1:
        return None
    address = struct.unpack('<Q', spcr[44:52])[0]
    try:
        tty = addrtoname[address]
        speed = speedmap[spcr[58]]
    except KeyError:
        return None
    retval = {'tty': tty, 'speed': speed}
    # Open the port unconditionally: the carrier check below needs the
    # descriptor even when no baud rate has to be applied.
    ttyf = os.open(tty, os.O_RDWR | os.O_NOCTTY)
    try:
        if speed:
            currattr = termios.tcgetattr(ttyf)
            # Indexes 4 and 5 of the attribute list are ispeed and ospeed.
            currattr[4:6] = [0, termiobaud[speed]]
            termios.tcsetattr(ttyf, termios.TCSANOW, currattr)
        modembits = struct.unpack('<I', fcntl.ioctl(
            ttyf, termios.TIOCMGET, b'\x00\x00\x00\x00'))[0]
        retval['connected'] = bool(modembits & termios.TIOCM_CAR)
    finally:
        os.close(ttyf)
    return retval

def main():
    """Adjust the grub configuration when a connected firmware serial console is found.

    Returns without doing anything unless get_serial_config() reports a
    connected port. Red Hat family systems (those with
    /etc/redhat-release) get deserialize_grub_rh(); systems whose
    /etc/os-release mentions ``Ubuntu`` get fixup_ubuntu_grub_serial().
    """
    serialcfg = get_serial_config()
    if not (serialcfg and serialcfg['connected']):
        return
    if os.path.exists('/etc/redhat-release'):  # redhat family
        deserialize_grub_rh()
        return
    if not os.path.exists('/etc/os-release'):
        return
    with open('/etc/os-release') as osrelease:
        isubuntu = 'Ubuntu' in osrelease.read()
    if isubuntu:
        fixup_ubuntu_grub_serial()

# Run main() only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
