#!/usr/bin/env python3
"""systemctl: subset of systemctl used for image construction

Mask/preset systemd units
"""

import argparse
import fnmatch
import os
import re
import sys

from collections import namedtuple
from pathlib import Path

# Version number of this systemctl subset
version = 1.0

# Default filesystem root, used when no --root option is given
ROOT = Path("/")
# Directories expressed relative to the root so they can be joined onto it
SYSCONFDIR = Path("etc")
BASE_LIBDIR = Path("lib")
LIBDIR = Path("usr", "lib")

# Root-relative systemd directories to search, earlier entries taking
# precedence; empty at import time and filled in by main()
locations = list()


class SystemdFile():
    """Class representing a single systemd configuration file

    The file itself is parsed first, followed by every ``*.conf`` drop-in
    found in the ``<file name>.d`` directory below ``system`` in each
    search location. Values from later files are appended to the values
    already collected for the same section and key.
    """
    # Blank lines and comment lines (starting with '#' or ';')
    _SKIP_RE = re.compile(r"^\s*([#;]|$)")
    _SECTION_RE = re.compile(r"^\s*\[(?P<section>.*)\]")
    # The key ends at the first '=' so that values may themselves contain '='
    _KV_RE = re.compile(r"^\s*(?P<key>[^\s=]+)\s*=\s*(?P<value>.*)")

    def __init__(self, root, path):
        # Maps section name -> {key: [whitespace separated value tokens]}
        self.sections = dict()
        self._parse(root, path)
        dirname = path.name + ".d"
        for location in locations:
            dropin_dir = root / location / "system" / dirname
            for dropin in sorted(dropin_dir.glob("*.conf")):
                self._parse(root, dropin)

    def _parse(self, root, path):
        """Parse a systemd syntax configuration file

        Lines that are neither a section header nor a ``key=value``
        assignment, and assignments that appear before the first section
        header, are ignored.

        Args:
            root: A pathlib.Path object for the root of the filesystem
                being operated on
            path: A pathlib.Path object pointing to the file

        """
        section = None

        # Path.resolve() only raises for a missing target on Python < 3.6,
        # so detect a dangling link explicitly instead.
        if path.is_symlink() and not path.exists():
            # broken symlink, try relative to root
            target = Path(os.readlink(str(path)))
            if target.is_absolute():
                path = root / target.relative_to(ROOT)

        with path.open() as f:
            for line in f:
                if self._SKIP_RE.match(line):
                    continue

                line = line.strip()
                m = self._SECTION_RE.match(line)
                if m:
                    section = self.sections.setdefault(m.group('section'),
                                                       dict())
                    continue

                # A trailing backslash joins the next line; the backslash
                # itself is replaced by a space. At end of file readline()
                # returns "", which leaves the line ending in a space and
                # therefore ends the loop.
                while line.endswith("\\"):
                    line = line[:-1] + " " + f.readline().strip()

                m = self._KV_RE.match(line)
                if m is None or section is None:
                    # Not an assignment, or an assignment outside of any
                    # section: nothing to record
                    continue

                values = section.setdefault(m.group('key'), list())
                values.extend(m.group('value').split())

    def get(self, section, prop):
        """Get a property from section

        Args:
            section: Section to retrieve property from
            prop: Property to retrieve

        Returns:
            List representing all properties of type prop in section.

        Raises:
            KeyError: if ``section`` or ``prop`` not found
        """
        return self.sections[section][prop]


class Presets():
    """Class representing all systemd presets"""
    def __init__(self, scope, root):
        # Directive(action, unit_name) tuples in the order they were read;
        # state() returns the action of the first one that matches
        self.directives = list()
        self._collect_presets(scope, root)

    def _parse_presets(self, presets):
        """Parse presets out of a set of preset files

        Args:
            presets: iterable of pathlib.Path objects, read in the order
                given

        Exits the program if a line is neither a directive, a comment nor
        blank.
        """
        skip_re = re.compile(r"^\s*([#;]|$)")
        directive_re = re.compile(r"^\s*(?P<action>enable|disable)\s+(?P<unit_name>.+)")

        Directive = namedtuple("Directive", "action unit_name")
        for preset in presets:
            with preset.open() as f:
                for line in f:
                    m = directive_re.match(line)
                    if m:
                        # Trailing whitespace is not part of the pattern;
                        # keeping it would stop the pattern from ever
                        # matching a unit name
                        directive = Directive(action=m.group('action'),
                                              unit_name=m.group('unit_name').strip())
                        self.directives.append(directive)
                    elif not skip_re.match(line):
                        sys.exit("Unparsed preset line in {}".format(preset))

    def _collect_presets(self, scope, root):
        """Collect list of preset files

        Args:
            scope: name of the preset directory below each search location
            root: pathlib.Path object for the root of the filesystem
        """
        presets = dict()
        for location in locations:
            paths = (root / location / scope).glob("*.preset")
            for path in paths:
                # earlier names override later ones
                if path.name not in presets:
                    presets[path.name] = path

        # Parse in file name order, regardless of which location supplied
        # each file
        self._parse_presets([v for k, v in sorted(presets.items())])

    def state(self, unit_name):
        """Return state of preset for unit_name

        Args:
            unit_name: name of the unit

        Returns:
            None: no matching preset
            `enable`: unit_name is enabled
            `disable`: unit_name is disabled
        """
        for directive in self.directives:
            if fnmatch.fnmatch(unit_name, directive.unit_name):
                return directive.action

        return None


def add_link(path, target):
    """Create the symlink ``path`` pointing at ``target``

    Missing parent directories are created. If ``path`` is already a
    symlink it is left untouched, whatever it points at.

    Args:
        path: pathlib.Path object for the link to create
        target: what the link should point to
    """
    parent = path.parent
    try:
        parent.mkdir(parents=True)
    except FileExistsError:
        # Parent is already present, nothing to create
        pass

    if path.is_symlink():
        return

    print("ln -s {} {}".format(target, path))
    path.symlink_to(target)


class SystemdUnitNotFoundError(Exception):
    """Raised when a unit file is not found in any search location"""


class SystemdUnit():
    """Class representing a systemd unit that can be enabled or masked

    All links are created below ``root / SYSCONFDIR / "systemd" / "system"``.
    """
    def __init__(self, root, unit):
        self.root = root
        self.unit = unit
        self.config = None

    @staticmethod
    def _install_values(config, prop):
        """Return the values of ``prop`` in the [Install] section

        Returns an empty list when the section or the property is absent.
        """
        try:
            return config.get('Install', prop)
        except KeyError:
            return []

    def _path_for_unit(self, unit):
        """Return the path of ``unit`` in the first location that has it

        Raises:
            SystemdUnitNotFoundError: if no location contains ``unit``
        """
        for location in locations:
            path = self.root / location / "system" / unit
            # is_symlink() also accepts links whose target does not exist
            if path.exists() or path.is_symlink():
                return path

        raise SystemdUnitNotFoundError(self.root, unit)

    def _process_deps(self, config, service, location, prop, dirstem):
        """Link ``service`` into the directories of the units that pull it in

        Args:
            config: parsed unit file
            service: name the link is given
            location: path of the unit file, below ``self.root``
            prop: [Install] property listing the dependent units
            dirstem: directory suffix, giving ``<dependent>.<dirstem>``
        """
        systemdir = self.root / SYSCONFDIR / "systemd" / "system"

        # Links point at the path the unit has once self.root is "/"
        target = ROOT / location.relative_to(self.root)
        for dependent in self._install_values(config, prop):
            wants = systemdir / "{}.{}".format(dependent, dirstem) / service
            add_link(wants, target)

    def enable(self, _seen=None):
        """Enable the unit by creating the links its [Install] section asks for

        Units listed in ``Also=`` are enabled as well.

        Args:
            _seen: internal; set of unit names already handled in this
                chain of ``Also=`` references. Callers should not pass it.

        Raises:
            SystemdUnitNotFoundError: if the unit, its template, or a unit
                named in ``Also=`` is not found
        """
        # Units commonly name each other in Also= (a service and its
        # socket, for example); remember what was visited so that such
        # cycles terminate.
        if _seen is None:
            _seen = set()
        if self.unit in _seen:
            return
        _seen.add(self.unit)

        # if we're enabling an instance, first extract the actual instance
        # then figure out what the template unit is
        template = re.match(r"[^@]+@(?P<instance>[^\.]*)\.", self.unit)
        if template:
            instance = template.group('instance')
            unit = re.sub(r"@[^\.]*\.", "@.", self.unit, count=1)
        else:
            instance = None
            unit = self.unit

        path = self._path_for_unit(unit)

        if path.is_symlink():
            # ignore aliases
            return

        config = SystemdFile(self.root, path)
        if instance == "":
            # A bare template can only be enabled through its default
            # instance
            defaults = self._install_values(config, 'DefaultInstance')
            if not defaults:
                # no default instance, so nothing to enable
                return

            service = self.unit.replace("@.", "@{}.".format(defaults[0]))
        else:
            service = self.unit

        self._process_deps(config, service, path, 'WantedBy', 'wants')
        self._process_deps(config, service, path, 'RequiredBy', 'requires')

        for also in self._install_values(config, 'Also'):
            SystemdUnit(self.root, also).enable(_seen)

        systemdir = self.root / SYSCONFDIR / "systemd" / "system"
        target = ROOT / path.relative_to(self.root)
        for dest in self._install_values(config, 'Alias'):
            add_link(systemdir / dest, target)

    def mask(self):
        """Mask the unit by linking its name to /dev/null"""
        systemdir = self.root / SYSCONFDIR / "systemd" / "system"
        add_link(systemdir / self.unit, "/dev/null")


def collect_services(root):
    """Collect the names of all unit files

    Args:
        root: pathlib.Path object for the root of the filesystem

    Returns:
        Set with the name of every non-directory entry found directly in
        the ``system`` directory of any search location.
    """
    names = set()
    for base in locations:
        system_dir = root / base / "system"
        names.update(entry.name
                     for entry in system_dir.glob("*")
                     if not entry.is_dir())

    return names


def preset_all(root):
    """Enable every unit that the presets do not disable

    Units without a matching preset are enabled as well.

    Args:
        root: pathlib.Path object for the root of the filesystem
    """
    presets = Presets('system-preset', root)
    services = collect_services(root)

    # Existing links are never replaced, so when two units ask for the
    # same link the first one processed wins. Iterate in sorted order
    # rather than in set order so the outcome is the same on every run.
    for service in sorted(services):
        state = presets.state(service)

        if state == "enable" or state is None:
            SystemdUnit(root, service).enable()

    # If we populate the systemd links we also create /etc/machine-id, which
    # allows systemd to boot with the filesystem read-only before generating
    # a real value and then committing it back.
    #
    # For the stateless configuration, where /etc is generated at runtime
    # (for example on a tmpfs), this script shouldn't run at all and we
    # allow systemd to completely populate /etc.
    (root / SYSCONFDIR / "machine-id").touch()


def main():
    """Parse the command line and run the requested command

    Fills in the module-level ``locations`` list before dispatching to
    ``mask``, ``enable`` or ``preset-all``.
    """
    if sys.version_info < (3, 4, 0):
        sys.exit("Python 3.4 or greater is required")

    parser = argparse.ArgumentParser()
    parser.add_argument('command', nargs=1,
                        choices=['enable', 'mask', 'preset-all'])
    parser.add_argument('service', nargs=argparse.REMAINDER)
    parser.add_argument('--root')
    parser.add_argument('--preset-mode',
                        choices=['full', 'enable-only', 'disable-only'],
                        default='full')

    args = parser.parse_args()

    if args.root:
        root = Path(args.root)
    else:
        root = ROOT

    search_dirs = [SYSCONFDIR / "systemd"]
    # Handle the usrmerge case by ignoring /lib when it's a symlink
    if not (root / BASE_LIBDIR).is_symlink():
        search_dirs.append(BASE_LIBDIR / "systemd")
    search_dirs.append(LIBDIR / "systemd")
    locations.extend(search_dirs)

    (command,) = args.command
    units = args.service

    if command == "preset-all":
        if units:
            sys.exit("Too many arguments.")
        if args.preset_mode != "enable-only":
            sys.exit("Only enable-only is supported as preset-mode.")
        preset_all(root)
        return

    if command == "mask":
        for name in units:
            SystemdUnit(root, name).mask()
        return

    if command == "enable":
        for name in units:
            SystemdUnit(root, name).enable()
        return

    # argparse restricts 'command' to the choices handled above
    raise RuntimeError()


# Run the command-line interface only when executed as a script
if __name__ == '__main__':
    main()