#!/usr/bin/python

import glob
import json
import os
import socket
import sys
import time
import shlex
import subprocess
try:
    import yaml
except ImportError:
    pass
# Provide a load_source(name, path) helper that imports a module from an
# explicit file path.  importlib.machinery is preferred; interpreters that
# lack it fall back to the equivalent helper from the legacy imp module.
try:
    from importlib.machinery import SourceFileLoader
    def load_source(mod, path):
        """Load the Python source file at *path* as a module named *mod*."""
        return SourceFileLoader(mod, path).load_module()
except ImportError:
    from imp import load_source

# The confluent apiclient is loaded by path; if the first location cannot be
# read (IOError) the copy under /etc/confluent is used instead.
try:
    apiclient = load_source('apiclient', '/opt/confluent/bin/apiclient')
except IOError:
    apiclient = load_source('apiclient', '/etc/confluent/apiclient')

def add_lla(iface, mac):
    """Add an IPv6 link-local address derived from *mac* to *iface*.

    The address is built from the colon-separated MAC octets in modified
    EUI-64 form (first octet XOR 2, ``ff:fe`` inserted in the middle).
    IPv6 is enabled on the interface before the address is added.

    Returns the address string (with ``/64`` suffix) on success, or None if
    enabling IPv6 or adding the address failed for any reason.
    """
    octets = mac.split(':')
    # Flip the universal/local bit of the leading octet.
    flipped = int(octets[0], 16) ^ 2
    lla = 'fe80::{0:x}{1}:{2}ff:fe{3}:{4}{5}/64'.format(flipped, *octets[1:6])
    try:
        knob = '/proc/sys/net/ipv6/conf/{0}/disable_ipv6'.format(iface)
        with open(knob, 'w') as sysctl:
            sysctl.write('0')
        subprocess.check_call(
            ['ip', 'addr', 'add', 'dev', iface, lla, 'scope', 'link'])
    except Exception:
        # Best effort only: the caller treats None as "nothing was added".
        return None
    else:
        return lla

#cli = apiclient.HTTPSClient(json=True)
#c = cli.grab_url_with_status('/confluent-api/self/netcfg')
def add_missing_llas():
    """Give every carrier-up, non-loopback link an IPv6 link-local address.

    NetworkManager goes out of its way to suppress IPv6 LLAs, so any
    interface lacking an ``fe80::`` address gets one added via add_lla().
    Links that are not flagged UP are brought up first.

    Returns a dict mapping interface name to the address that was added.
    """
    macbyiface = {}
    links = subprocess.check_output(['ip', '-br', 'l']).decode('utf8')
    for entry in links.split('\n'):
        fields = entry.strip().split()
        if not fields:
            continue
        name, flags = fields[0], fields[-1]
        if 'LOOPBACK' in flags or 'NO-CARRIER' in flags:
            continue
        if 'UP' not in flags:
            subprocess.call(['ip', 'link', 'set', name, 'up'])
        # Third column of the brief listing is the hardware address.
        macbyiface[name] = fields[2]

    v6byiface = {}
    v6info = subprocess.check_output(['ip', '-br', '-6', 'a']).decode('utf8')
    for entry in v6info.split('\n'):
        fields = entry.strip().split(None, 2)
        if fields:
            # Everything after name and state is the address list.
            v6byiface[fields[0]] = fields[2]

    created = {}
    for name in macbyiface:
        current = v6byiface.get(name, '').split()
        if any(addr.startswith('fe80::') for addr in current):
            continue
        lla = add_lla(name, macbyiface[name])
        if lla:
            created[name] = lla
    return created

def rm_tmp_llas(tmpllas):
    """Delete the addresses in *tmpllas* (interface name -> address).

    Raises subprocess.CalledProcessError if an ``ip addr del`` fails.
    """
    for ifname, lla in tmpllas.items():
        subprocess.check_call(['ip', 'addr', 'del', 'dev', ifname, lla])

def await_tentative():
    """Wait until ``ip a`` no longer reports any 'tentative' address.

    Polls once per second and gives up after ten sleeps.
    """
    remaining = 10
    while True:
        if b'tentative' not in subprocess.check_output(['ip', 'a']):
            return
        if not remaining:
            return
        remaining -= 1
        time.sleep(1)

def map_idx_to_name():
    """Parse ``ip l`` into interface index and link type lookups.

    Returns a tuple ``(names, kinds)``:
      * names maps interface index (int) to interface name, leaving out
        any link whose header line carries a ``master`` keyword;
      * kinds maps interface name to the type taken from its ``link/<type>``
        line, with ``ether`` reported as ``ethernet``.
    """
    namebyidx = {}
    kindbydev = {}
    lastdev = None
    listing = subprocess.check_output(['ip', 'l']).decode('utf8')
    for row in listing.splitlines():
        if row.startswith(' '):
            # Indented rows are details of the most recent header row.
            if 'link/' in row:
                kind = row.split()[0].split('/')[1]
                kindbydev[lastdev] = 'ethernet' if kind == 'ether' else kind
            continue
        idxtext, devname, remainder = row.split(':', 2)
        devname = devname.strip()
        lastdev = devname
        if 'master' in remainder.split():
            continue
        namebyidx[int(idxtext)] = devname
    return namebyidx, kindbydev


def get_interface_name(iname, settings):
    """Pick the interface name(s) a network's settings should apply to.

    An explicit ``interface_names`` setting wins; otherwise *iname* is used
    when ``current_nic`` is set.  Returns None when neither applies.
    """
    configured = settings.get('interface_names')
    if configured:
        return configured
    return iname if settings.get('current_nic') else None

class NetplanManager(object):
    """Apply confluent network settings through netplan YAML files.

    Ethernet stanzas already present under /etc/netplan are read at
    construction time and kept in ``cfgbydev`` (device name -> settings).
    apply_configuration() merges the requested settings into that view and
    writes ``/etc/netplan/<dev>-confluentcfg.yaml`` only when something
    actually changed.  Only static addressing is handled; other address
    methods leave the existing stanza untouched.
    """

    def __init__(self, deploycfg):
        self.cfgbydev = {}
        self.read_connections()
        self.deploycfg = deploycfg

    def read_connections(self):
        """Load the ``ethernets`` section of every netplan file."""
        for plan in glob.glob('/etc/netplan/*.y*ml'):
            with open(plan) as planfile:
                planinfo = yaml.safe_load(planfile)
            if not planinfo:
                continue
            # Tolerate files whose 'network' or 'ethernets' key is null.
            netinfo = planinfo.get('network', None) or {}
            nicinfo = netinfo.get('ethernets', None) or {}
            for devname in nicinfo:
                if devname == 'lo':
                    continue
                if 'gateway4' in nicinfo[devname]:
                    # normalize deprecated syntax on read in
                    gw4 = nicinfo[devname]['gateway4']
                    del nicinfo[devname]['gateway4']
                    routeinfo = nicinfo[devname].get('routes', [])
                    for ri in routeinfo:
                        if ri.get('via', None) == gw4 and ri.get('to', None) in ('default', '0.0.0.0/0', '0/0'):
                            break
                    else:
                        routeinfo.append({
                            'to': 'default',
                            'via': gw4
                        })
                    nicinfo[devname]['routes'] = routeinfo
                self.cfgbydev[devname] = nicinfo[devname]

    def apply_configuration(self, cfg):
        """Merge *cfg* into the netplan configuration and apply it.

        *cfg* is a dict with ``interfaces`` (exactly one device name) and
        ``settings`` (the per-network settings dict).  Raises Exception when
        more than one interface is given, as bonds are not supported here.
        """
        devnames = cfg['interfaces']
        if len(devnames) != 1:
            raise Exception('Multi-nic team/bonds not yet supported')
        stgs = cfg['settings']
        deploycfg = self.deploycfg or {}
        needcfgapply = False
        for devname in devnames:
            needcfgwrite = False
            # ipv6_method missing at uconn...
            for family in ('ipv6', 'ipv4'):
                if stgs.get(family + '_method', None) == 'static':
                    curraddr = stgs[family + '_address']
                    currips = self.getcfgarrpath([devname, 'addresses'])
                    if curraddr not in currips:
                        needcfgwrite = True
                        currips.append(curraddr)
            if stgs.get('mtu', None):
                mtu = int(stgs['mtu'])
                # The device entry itself is a mapping, not a list; a new
                # MTU also has to trigger a rewrite of the plan file.
                devdict = self.getcfgarrpath([devname], leaftype=dict)
                if devdict.get('mtu', None) != mtu:
                    needcfgwrite = True
                    devdict['mtu'] = mtu
            for gwaddr in (stgs.get('ipv4_gateway', None),
                           stgs.get('ipv6_gateway', None)):
                if not gwaddr:
                    continue
                cfgroutes = self.getcfgarrpath([devname, 'routes'])
                if not any(rinfo.get('via', None) == gwaddr for rinfo in cfgroutes):
                    needcfgwrite = True
                    cfgroutes.append({'via': gwaddr, 'to': 'default'})
            dnsips = deploycfg.get('nameservers', [])
            dnsdomain = deploycfg.get('dnsdomain', '')
            if dnsips:
                currdnsips = self.getcfgarrpath([devname, 'nameservers', 'addresses'])
                for dnsip in dnsips:
                    if dnsip and dnsip not in currdnsips:
                        needcfgwrite = True
                        currdnsips.append(dnsip)
            if dnsdomain:
                currdnsdomain = self.getcfgarrpath([devname, 'nameservers', 'search'])
                if dnsdomain not in currdnsdomain:
                    needcfgwrite = True
                    currdnsdomain.append(dnsdomain)
            if needcfgwrite:
                needcfgapply = True
                newcfg = {'network': {'version': 2, 'ethernets': {devname: self.cfgbydev[devname]}}}
                # Create the file owner-only; always restore the umask, even
                # when the write fails.
                oumask = os.umask(0o77)
                try:
                    with open('/etc/netplan/{0}-confluentcfg.yaml'.format(devname), 'w') as planout:
                        planout.write(yaml.dump(newcfg))
                finally:
                    os.umask(oumask)
        if needcfgapply:
            subprocess.call(['netplan', 'apply'])

    def getcfgarrpath(self, devpath, leaftype=list):
        """Return the container at *devpath* in ``cfgbydev``, creating it.

        *devpath* is a list of keys.  Missing intermediate levels are created
        as dicts; a missing final element is created with ``leaftype()``
        (a list by default, pass ``dict`` for mapping nodes).
        """
        currptr = self.cfgbydev
        for k in devpath[:-1]:
            if k not in currptr:
                currptr[k] = {}
            currptr = currptr[k]
        if devpath[-1] not in currptr:
            currptr[devpath[-1]] = leaftype()
        return currptr[devpath[-1]]
        


class WickedManager(object):
    """Apply confluent network settings through wicked ifcfg/ifroute files.

    ``cfgbydev`` maps device name to the KEY=VALUE settings read from its
    existing ``ifcfg-<dev>`` file; devices whose file contains a line that
    cannot be tokenized are left out entirely.
    """

    # Directory holding the ifcfg-*/ifroute-* files.
    cfgdir = '/etc/sysconfig/network'

    # confluent team_mode names mapped to the bonding driver's mode names;
    # values not listed here are passed through unchanged.
    bondtypes = {
        'lacp': '802.3ad',
        'loadbalance': 'balance-alb',
        'roundrobin': 'balance-rr',
        'activebackup': 'active-backup',
    }

    def __init__(self):
        self.teamidx = 0
        self.read_connections()

    def _cfgpath(self, kind, name):
        """Return the path of the ``<kind>-<name>`` file in ``cfgdir``."""
        return os.path.join(self.cfgdir, '{0}-{1}'.format(kind, name))

    def read_connections(self):
        """Populate ``cfgbydev`` from the ifcfg-* files in ``cfgdir``."""
        self.cfgbydev = {}
        prefix = os.path.join(self.cfgdir, 'ifcfg-')
        for ifcfg in glob.glob(prefix + '*'):
            devname = ifcfg[len(prefix):]
            if devname == 'lo':
                continue
            with open(ifcfg) as cfgfile:
                cfglines = cfgfile.read().splitlines()
            currcfg = {}
            parseable = True
            for cfg in cfglines:
                cfg = cfg.split('#', 1)[0]
                try:
                    kv = ' '.join(shlex.split(cfg)).split('=', 1)
                except ValueError:
                    # unparseable line, likely having something we can't
                    # handle; disregard this device's file as a whole
                    parseable = False
                    break
                if len(kv) != 2:
                    continue
                k, v = kv
                currcfg[k.strip()] = v.strip()
            if parseable:
                self.cfgbydev[devname] = currcfg

    def apply_configuration(self, cfg):
        """Write ifcfg/ifroute files for *cfg* and bring the device up.

        *cfg* is a dict with ``interfaces`` (device names) and ``settings``.
        More than one interface creates a bond named by the
        ``connection_name`` setting (``bond<N>`` is assigned when absent);
        this requires ``team_mode``, otherwise a warning is printed and
        nothing is changed.
        """
        stgs = cfg['settings']
        ipcfg = 'STARTMODE=auto\n'
        routecfg = ''
        bootproto4 = stgs.get('ipv4_method', 'none')
        bootproto6 = stgs.get('ipv6_method', 'none')
        if bootproto4 == 'dhcp' and bootproto6 == 'dhcp':
            ipcfg += 'BOOTPROTO=dhcp\n'
        elif bootproto4 == 'dhcp':
            ipcfg += 'BOOTPROTO=dhcp4\n'
        elif bootproto6 == 'dhcp':
            ipcfg += 'BOOTPROTO=dhcp6\n'
        else:
            ipcfg += 'BOOTPROTO=static\n'
        if stgs.get('ipv4_address', None):
            ipcfg += 'IPADDR=' + stgs['ipv4_address'] + '\n'
        v4gw = stgs.get('ipv4_gateway', None)
        if stgs.get('ipv6_address', None):
            ipcfg += 'IPADDR_V6=' + stgs['ipv6_address'] + '\n'
        v6gw = stgs.get('ipv6_gateway', None)
        cname = None
        if len(cfg['interfaces']) > 1:  # creating new team
            if not stgs.get('team_mode', None):
                sys.stderr.write("Warning, multiple interfaces ({0}) without a team_mode, skipping setup\n".format(','.join(cfg['interfaces'])))
                return
            if not stgs.get('connection_name', None):
                stgs['connection_name'] = 'bond{0}'.format(self.teamidx)
                self.teamidx += 1
            cname = stgs['connection_name']
            teammode = self.bondtypes.get(stgs['team_mode'], stgs['team_mode'])
            with open(self._cfgpath('ifcfg', cname), 'w') as teamout:
                teamout.write(ipcfg)
                teamout.write("BONDING_MODULE_OPTS='mode={0} miimon=100'\nBONDING_MASTER=yes\n".format(teammode))
                for idx, iface in enumerate(cfg['interfaces'], 1):
                    subprocess.call(['wicked', 'ifdown', iface])
                    # Members must not keep a configuration of their own;
                    # remove each file independently so a missing ifcfg
                    # does not leave a stale ifroute behind.
                    for kind in ('ifcfg', 'ifroute'):
                        try:
                            os.remove(self._cfgpath(kind, iface))
                        except OSError:
                            pass
                    teamout.write('BONDING_SLAVE{0}={1}\n'.format(idx, iface))
        else:
            cname = list(cfg['interfaces'])[0]
            priorcfg = self.cfgbydev.get(cname, {})
            # Carry over any TEAM_* settings from the previous file.
            for cf in priorcfg:
                if cf.startswith('TEAM_'):
                    ipcfg += '{0}={1}\n'.format(cf, priorcfg[cf])
            with open(self._cfgpath('ifcfg', cname), 'w') as iout:
                iout.write(ipcfg)
        if v4gw:
            routecfg += 'default {0} - {1}\n'.format(v4gw, cname)
        if v6gw:
            routecfg += 'default {0} - {1}\n'.format(v6gw, cname)
        if routecfg:
            with open(self._cfgpath('ifroute', cname), 'w') as routeout:
                routeout.write(routecfg)
        subprocess.call(['wicked', 'ifup', cname])


class NetworkManager(object):
    """Apply confluent network settings through ``nmcli``.

    State gathered by read_connections():
      * connections      - uuid -> {'name', 'uuid', 'type', 'dev'}
      * uuidbyname       - connection name -> uuid
      * uuidbydev        - device name -> uuid (connections bound to a device)
      * connectiondetail - uuid -> {property: value} from ``nmcli c s <uuid>``
    """

    # confluent team_mode names mapped to bonding mode names; values not
    # listed here are passed through unchanged.
    bondtypes = {
            'lacp': '802.3ad',
            'loadbalance': 'balance-alb',
            'roundrobin': 'balance-rr',
            'activebackup': 'active-backup',
    }

    def __init__(self, devtypes, deploycfg):
        self.deploycfg = deploycfg
        self.connections = {}
        self.uuidbyname = {}
        self.uuidbydev = {}
        self.connectiondetail = {}
        self.read_connections()
        self.teamidx = 0
        self.devtypes = devtypes

    @staticmethod
    def _split_terse_connection(line):
        """Split one ``nmcli -t c`` row into (name, uuid, type, device).

        Terse output separates fields with ':' and backslash-escapes ':' and
        '\\' inside values, so the row is split from the right (only the
        name is expected to contain escapes) and the name is unescaped.
        Returns None when the row does not have four fields.
        """
        parts = line.rsplit(':', 3)
        if len(parts) != 4:
            return None
        rawname, uuid, ctype, dev = parts
        name = []
        escaped = False
        for ch in rawname:
            if escaped:
                name.append(ch)
                escaped = False
            elif ch == '\\':
                escaped = True
            else:
                name.append(ch)
        return ''.join(name), uuid, ctype, dev

    @staticmethod
    def _as_cli_args(cmdargs):
        """Flatten a {property: value} dict into nmcli argument pairs."""
        cargs = []
        for arg in cmdargs:
            cargs.append(arg)
            cargs.append('{}'.format(cmdargs[arg]))
        return cargs

    def read_connections(self):
        """Refresh the cached view of all NetworkManager connections."""
        self.connections = {}
        self.uuidbyname = {}
        self.uuidbydev = {}
        self.connectiondetail = {}
        ci = subprocess.check_output(['nmcli', '-t', 'c']).decode('utf8')
        for inf in ci.splitlines():
            fields = self._split_terse_connection(inf)
            if not fields:
                continue
            n, u, t, dev = fields
            if n == 'NAME':
                continue
            if dev == '--':
                dev = None
            self.uuidbyname[n] = u
            if dev:
                self.uuidbydev[dev] = u
            self.connections[u] = {'name': n, 'uuid': u, 'type': t, 'dev': dev}
            deats = {}
            for deat in subprocess.check_output(['nmcli', 'c', 's', u]).decode('utf8').splitlines():
                if ':' not in deat:
                    # not a "property: value" row, nothing to record
                    continue
                k, v = deat.split(':', 1)
                v = v.strip()
                # Unset values and values marked as defaults are not kept.
                if v == '--':
                    continue
                if '(default)' in v:
                    continue
                deats[k] = v
            self.connectiondetail[u] = deats

    def add_team_member(self, team, member):
        """Attach device *member* to the bond connection *team*.

        Does nothing when the member's current connection already has
        *team* as its master.  Otherwise any connection named after the
        member is deleted and recreated as a port of *team*; DNS and DHCP
        hostname properties found on the member's previous connection are
        copied onto the team connection.
        """
        bondcfg = {}
        if member in self.uuidbydev:
            myuuid = self.uuidbydev[member]
            deats = self.connectiondetail[myuuid]
            currteam = deats.get('connection.master', None)
            if currteam == team:
                return
            for stg in ('ipv4.dhcp-hostname', 'ipv4.dns', 'ipv6.dns', 'ipv6.dhcp-hostname', 'ipv4.dns-search', 'ipv6.dns-search'):
                if deats.get(stg, None):
                    bondcfg[stg] = deats[stg]
        if member in self.uuidbyname:
            subprocess.check_call(['nmcli', 'c', 'del', self.uuidbyname[member]])
        devtype = self.devtypes.get(member, 'bond-slave')
        subprocess.check_call(['nmcli', 'c', 'add', 'type', devtype, 'master', team, 'con-name', member, 'connection.interface-name', member])
        if bondcfg:
            args = []
            for parm in bondcfg:
                args.append(parm)
                args.append(bondcfg[parm])
            subprocess.check_call(['nmcli', 'c', 'm', team] + args)

    def apply_configuration(self, cfg, lastchance=False):
        """Create or modify the connection described by *cfg* and activate it.

        *cfg* is a dict with ``interfaces`` (device names) and ``settings``.
        Returns 1 when the single requested interface is not a known device,
        so the caller can retry after rescanning; with *lastchance* set a
        warning is printed in that case.  Returns None otherwise.
        """
        cmdargs = {}
        cmdargs['connection.autoconnect'] = 'yes'
        stgs = cfg['settings']
        cmdargs['ipv6.method'] = stgs.get('ipv6_method', 'link-local')
        if stgs.get('ipv6_address', None):
            cmdargs['ipv6.addresses'] = stgs['ipv6_address']
        cmdargs['ipv4.method'] = stgs.get('ipv4_method', 'disabled')
        if stgs.get('ipv4_address', None):
            cmdargs['ipv4.addresses'] = stgs['ipv4_address']
        if stgs.get('ipv4_gateway', None):
            cmdargs['ipv4.gateway'] = stgs['ipv4_gateway']
        if stgs.get('ipv6_gateway', None):
            cmdargs['ipv6.gateway'] = stgs['ipv6_gateway']
        if stgs.get('mtu', None):
            cmdargs['802-3-ethernet.mtu'] = stgs['mtu']
        dnsips = (self.deploycfg or {}).get('nameservers', [])
        if not dnsips:
            dnsips = []
        dns4 = []
        dns6 = []
        for dnsip in dnsips:
            if not dnsip:
                # skip empty/null entries rather than failing on them
                continue
            if '.' in dnsip:
                dns4.append(dnsip)
            elif ':' in dnsip:
                dns6.append(dnsip)
        if dns4:
            cmdargs['ipv4.dns'] = ','.join(dns4)
        if dns6:
            cmdargs['ipv6.dns'] = ','.join(dns6)
        if len(cfg['interfaces']) > 1:  # team time.. should be..
            if not cfg['settings'].get('team_mode', None):
                sys.stderr.write("Warning, multiple interfaces ({0}) without a team_mode, skipping setup\n".format(','.join(cfg['interfaces'])))
                return
            if not cfg['settings'].get('connection_name', None):
                cfg['settings']['connection_name'] = 'team{0}'.format(self.teamidx)
                self.teamidx += 1
            cname = cfg['settings']['connection_name']
            cargs = self._as_cli_args(cmdargs)
            teammode = self.bondtypes.get(stgs['team_mode'], stgs['team_mode'])
            subprocess.check_call(['nmcli', 'c', 'add', 'type', 'bond', 'con-name', cname, 'connection.interface-name', cname, 'bond.options', 'miimon=100,mode={}'.format(teammode)] + cargs)
            for iface in cfg['interfaces']:
                self.add_team_member(cname, iface)
            subprocess.check_call(['nmcli', 'c', 'u', cname])
        else:
            cname = stgs.get('connection_name', None)
            iname = list(cfg['interfaces'])[0]
            ctype = self.devtypes.get(iname, None)
            if not ctype:
                if lastchance:
                    sys.stderr.write("Warning, no device found for interface_name ({0}), skipping setup\n".format(iname))
                return 1
            if stgs.get('vlan_id', None):
                # Accept numeric as well as string identifiers.
                vlan = '{0}'.format(stgs['vlan_id'])
                if ctype == 'infiniband':
                    # InfiniBand partitions use a hexadecimal p-key.
                    vlan = '0x{0}'.format(vlan) if not vlan.startswith('0x') else vlan
                    cmdargs['infiniband.parent'] = iname
                    cmdargs['infiniband.p-key'] = vlan
                    iname = '{0}.{1}'.format(iname, vlan[2:])
                elif ctype == 'ethernet':
                    ctype = 'vlan'
                    cmdargs['vlan.parent'] = iname
                    cmdargs['vlan.id'] = vlan
                    iname = '{0}.{1}'.format(iname, vlan)
                else:
                    sys.stderr.write("Warning, unknown interface_name ({0}) device type ({1}) for VLAN/PKEY, skipping setup\n".format(iname, ctype))
                    return
            cname = iname if not cname else cname
            u = self.uuidbyname.get(cname, None)
            cargs = self._as_cli_args(cmdargs)
            if u:
                subprocess.check_call(['nmcli', 'c', 'm', u, 'connection.interface-name', iname] + cargs)
                subprocess.check_call(['nmcli', 'c', 'u', u])
            else:
                subprocess.check_call(['nmcli', 'c', 'add', 'type', ctype, 'con-name', cname, 'connection.interface-name', iname] + cargs)
                # Re-read to learn the uuid of the connection just created.
                self.read_connections()
                u = self.uuidbyname.get(cname, None)
                if u:
                    subprocess.check_call(['nmcli', 'c', 'u', u])



if __name__ == '__main__':
    # Usage: [-c <host>]
    #   -c <host>  after applying the configuration, wait (up to ten
    #              attempts) until a TCP connection to <host>:443 succeeds.

    def _fetch_with_retry(srv, url, attempts=3):
        """Fetch *url* from *srv*; retry failures, re-raising the last."""
        for attempt in range(1, attempts + 1):
            try:
                return apiclient.HTTPSClient(usejson=True, host=srv).grab_url_with_status(url)
            except Exception:
                if attempt == attempts:
                    raise
                time.sleep(1)

    def _pick_manager(devtypes, deploycfg):
        """Return a manager for the first backend found, or None."""
        if os.path.exists('/usr/sbin/netplan'):
            return NetplanManager(deploycfg)
        if os.path.exists('/usr/bin/nmcli'):
            return NetworkManager(devtypes, deploycfg)
        if os.path.exists('/usr/sbin/wicked'):
            return WickedManager()
        return None

    checktarg = None
    if '-c' in sys.argv:
        targidx = sys.argv.index('-c') + 1
        if targidx >= len(sys.argv):
            sys.stderr.write('-c requires a target host argument\n')
            sys.exit(1)
        checktarg = sys.argv[targidx]
    havefirewall = subprocess.call(['systemctl', 'status', 'firewalld']) == 0
    if havefirewall:
        subprocess.check_call(['systemctl', 'stop', 'firewalld'])
    try:
        tmpllas = add_missing_llas()
        await_tentative()
        idxmap, devtypes = map_idx_to_name()
        netname_to_interfaces = {}
        myaddrs = apiclient.get_my_addresses()
        srvs, _ = apiclient.scan_confluents()
        doneidxs = set([])
        dc = None
        if not srvs:  # the multicast scan failed, fallback to deploycfg cfg file
            with open('/etc/confluent/confluent.deploycfg', 'r') as dci:
                for cfgline in dci.read().split('\n'):
                    if cfgline.startswith('deploy_server:'):
                        srvs = [cfgline.split()[1]]
                        break
        for srv in srvs:
            # Connect once to learn which local interface reaches this server.
            try:
                s = socket.create_connection((srv, 443))
            except socket.error:
                continue
            try:
                myname = s.getsockname()
            finally:
                s.close()
            curridx = None
            if len(myname) == 4:
                # IPv6 socket names carry the interface index as last element.
                curridx = myname[-1]
            else:
                myname = socket.inet_pton(socket.AF_INET, myname[0])
                for addr in myaddrs:
                    if myname == addr[1]:
                        curridx = addr[-1]
            if curridx is not None and curridx in doneidxs:
                continue
            _, nc = _fetch_with_retry(srv, '/confluent-api/self/netcfg')
            nc = json.loads(nc)
            if not dc:
                _, dc = _fetch_with_retry(srv, '/confluent-api/self/deploycfg2')
                dc = json.loads(dc)
            # The index may be unknown (no address match, or a link that has
            # a master); settings relying on the current nic are then skipped.
            currnic = idxmap.get(curridx, None)
            inames = get_interface_name(currnic, nc.get('default', {}))
            if inames:
                for iname in inames.split(','):
                    if 'default' in netname_to_interfaces:
                        netname_to_interfaces['default']['interfaces'].add(iname)
                    else:
                        netname_to_interfaces['default'] = {'interfaces': set([iname]), 'settings': nc['default']}
            for netname in nc.get('extranets', {}):
                uname = '_' + netname
                inames = get_interface_name(currnic, nc['extranets'][netname])
                if inames:
                    for iname in inames.split(','):
                        if uname in netname_to_interfaces:
                            netname_to_interfaces[uname]['interfaces'].add(iname)
                        else:
                            netname_to_interfaces[uname] = {'interfaces': set([iname]), 'settings': nc['extranets'][netname]}
            doneidxs.add(curridx)
        if 'default' in netname_to_interfaces:
            # Interfaces claimed by an extra network are not also 'default'.
            for netn in netname_to_interfaces:
                if netn == 'default':
                    continue
                netname_to_interfaces['default']['interfaces'] -= netname_to_interfaces[netn]['interfaces']
            if not netname_to_interfaces['default']['interfaces']:
                del netname_to_interfaces['default']
        # Make sure VLAN/PKEY connections are created last
        netname_to_interfaces = dict(sorted(netname_to_interfaces.items(), key=lambda item: 'vlan_id' in item[1]['settings']))
        rm_tmp_llas(tmpllas)
        nm = _pick_manager(devtypes, dc)
        if nm is None and netname_to_interfaces:
            sys.stderr.write('No supported network configuration backend (netplan, nmcli, wicked) found\n')
            sys.exit(1)
        retrynics = []
        for netn in netname_to_interfaces:
            redo = nm.apply_configuration(netname_to_interfaces[netn])
            if redo == 1:
                retrynics.append(netn)
        if retrynics:
            # Devices may have appeared meanwhile; rescan and try once more.
            idxmap, devtypes = map_idx_to_name()
            nm = _pick_manager(devtypes, dc)
            for netn in retrynics:
                nm.apply_configuration(netname_to_interfaces[netn], lastchance=True)
    finally:
        # Never leave the firewall stopped, even if configuration failed.
        if havefirewall:
            subprocess.check_call(['systemctl', 'start', 'firewalld'])
    await_tentative()
    maxwait = 10
    while maxwait:
        try:
            tclient = apiclient.HTTPSClient(checkonly=True)
            tclient.check_connections()
            break
        except Exception:
            maxwait -= 1
            time.sleep(1)
    maxwait = 10
    if checktarg:
        while maxwait:
            try:
                addrinf = socket.getaddrinfo(checktarg, 443)[0]
                psock = socket.socket(addrinf[0], socket.SOCK_STREAM)
                try:
                    psock.settimeout(10)
                    psock.connect(addrinf[4])
                finally:
                    psock.close()
                break
            except Exception:
                maxwait -= 1
                time.sleep(1)


