# coding: utf-8
from __future__ import print_function, unicode_literals

import argparse
import json
import os
import re
import time

from .__init__ import ANYWIN, MACOS
from .authsrv import AXS, VFS, AuthSrv
from .bos import bos
from .util import chkcmd, json_hesc, min_ex, undot

class Fstab(object):
    """
    Looks up which filesystem type a path is stored on.

    The mount table is loaded into a VFS tree (one node per mountpoint)
    where each node's `realpath` attribute is repurposed to hold the
    filesystem type; `get` reads it back through `VFS._find`.
    Results are cached per path, and the cache plus the tree are dropped
    once they are older than `args.mtab_age` seconds.
    """

    def __init__(self, log, args, verbose):
        """
        log: callable invoked as log("fstab", msg, c)
        args: object providing `mtab_age`, the max age (seconds) of the cache
        verbose: when true, messages without a color (c=0) are logged too
        """
        self.log_func = log
        self.verbose = verbose

        self.warned = False  # whether the build_tab failure was logged already
        self.trusted = False  # True once build_tab succeeded (see get_unix)
        self.tab = None  # current mount tree; None means "rebuild on next lookup"
        self.oldtab = None  # previous tree, reused if the mount table is unchanged
        self.srctab = "a"  # fingerprint of the mount table that `tab` was built from
        self.cache = {}  # path -> (fstype, mountpoint)
        self.age = 0.0  # timestamp of the last cache reset
        self.maxage = args.mtab_age

    def log(self, msg, c=0):
        """
        Forward `msg` to the logger given to the constructor.

        Messages with a color/severity (`c` is truthy) are always sent;
        plain messages (`c` is falsy) are only sent when `verbose` is set.
        """
        if not c and not self.verbose:
            return
        self.log_func("fstab", msg, c)

    def get(self, path):
        """
        Return `(fstype, mountpoint)` for `path`; fstype is lowercase.

        If the lookup fails, the failure is logged and a default is
        returned instead ("vfat" on windows, otherwise "ext4") with an
        empty mountpoint; this method does not propagate lookup errors.
        """
        now = time.time()
        if now - self.age > self.maxage or len(self.cache) > 9000:
            # cache is too old or too big; start over, but keep the old
            # tree around so build_tab can reuse it if nothing changed
            self.age = now
            self.oldtab = self.tab or self.oldtab
            self.tab = None
            self.cache = {}

        mp = ""
        fs = "ext4"
        msg = "failed to determine filesystem at %r; assuming %s\n%s"

        if ANYWIN:
            fs = "vfat"
            try:
                path = self._winpath(path)
            except Exception:
                self.log(msg % (path, fs, min_ex()), 3)
                return fs, ""

        path = undot(path)
        try:
            return self.cache[path]
        except KeyError:
            pass

        try:
            fs, mp = self.get_w32(path) if ANYWIN else self.get_unix(path)
        except Exception:
            # keep the defaults assigned above
            self.log(msg % (path, fs, min_ex()), 3)

        fs = fs.lower()
        self.cache[path] = (fs, mp)
        self.log("found %s at %r, %r" % (fs, mp, path))
        return fs, mp

    def _winpath(self, path):
        """
        Reduce a windows path to a volume identifier.

        Returns "<vid>*<st_dev>" where vid is the text before the first
        colon (or the first path segment), or just the vid if the path
        cannot be stat'ed.
        """
        path = path.replace("/", "\\")
        vid = path.split(":", 1)[0].strip("\\").split("\\", 1)[0]
        try:
            return "{}*{}".format(vid, bos.stat(path).st_dev)
        except Exception:
            return vid

    def build_fallback(self):
        """Install a single-node tree labelled "idk" and mark it untrusted."""
        self.tab = VFS(self.log_func, "idk", "/", "/", AXS(), {})
        self.trusted = False

    def _from_sp_mount(self):
        """
        Parse the output of the `mount` command.

        Returns a dict of mountpoint -> filesystem type; lines which
        do not match the expected layout are ignored.
        """
        sptn = r"^.*? on (.*) type ([^ ]+) \(.*"
        if MACOS:
            sptn = r"^.*? on (.*) \(([^ ]+), .*"

        ptn = re.compile(sptn)
        so, _ = chkcmd(["mount"])
        dtab = {}
        for ln in so.split("\n"):
            m = ptn.match(ln)
            if not m:
                continue

            zs1, zs2 = m.groups()
            dtab[str(zs1)] = str(zs2)

        return dtab

    def _from_proc(self):
        """
        Parse /proc/self/mounts (at most the first 262144 bytes).

        Returns a dict of mountpoint -> filesystem type, with the octal
        escapes in the mountpoint decoded.
        """
        ret = {}
        with open("/proc/self/mounts", "rb", 262144) as f:
            src = f.read(262144).decode("utf-8", "replace").split("\n")
        for zsl in [x.split(" ") for x in src]:
            if len(zsl) < 3:
                continue
            zs = zsl[1]
            # whitespace and backslashes are stored as octal escapes;
            # backslash goes last so decoded text is never decoded twice
            zs = zs.replace("\\011", "\t").replace("\\012", "\n")
            zs = zs.replace("\\040", " ").replace("\\134", "\\")
            ret[zs] = zsl[2]
        return ret

    def build_tab(self):
        """
        (Re)build `self.tab` from the current mount table.

        If the mount table is identical to the one the previous tree was
        built from, that tree (`self.oldtab`) is reused as-is.
        Raises if the mount table cannot be read or is empty; the caller
        (get_unix) handles that by switching to the fallback tree.
        """
        self.log("inspecting mtab for changes")
        dtab = self._from_sp_mount() if MACOS else self._from_proc()

        srctab = str(sorted(dtab.items()))
        if srctab == self.srctab:
            self.tab = self.oldtab
            return

        self.log("mtab has changed; reevaluating support for sparse files")

        fuses = [mp for mp, fs in dtab.items() if fs == "fuseblk"]
        if fuses and not MACOS:
            # "fuseblk" hides the real filesystem type; ask lsblk for it.
            # best-effort: on any failure the "fuseblk" labels are kept
            try:
                try:
                    so, _ = chkcmd(["lsblk", "-nrfo", "FSTYPE,MOUNTPOINT"])
                except Exception:
                    # retry with the alternative column name
                    so, _ = chkcmd(["lsblk", "-nrfo", "FSTYPE,MOUNTPOINTS"])
                for ln in so.split("\n"):
                    zsl = ln.split(" ", 1)
                    if len(zsl) != 2:
                        continue
                    fs, mp = zsl
                    if mp in fuses:
                        dtab[mp] = fs
            except Exception:
                pass

        # shortest mountpoint first, so the first entry becomes the tree root
        tab1 = list(dtab.items())
        tab1.sort(key=lambda x: (len(x[0]), x[0]))
        path1, fs1 = tab1[0]
        tab = VFS(self.log_func, fs1, path1, path1, AXS(), {})
        for path, fs in tab1[1:]:
            zs = path.lstrip("/")
            tab.add(fs, zs, zs)

        self.tab = tab
        self.srctab = srctab

    def relabel(self, path, nval):
        """
        Override the filesystem type recorded for `path` with `nval`.

        Clears the lookup cache. Expects `self.tab` to be populated
        already (for example by an earlier call to `get`).
        """
        self.cache = {}
        if ANYWIN:
            path = self._winpath(path)

        path = undot(path)
        ptn = re.compile(r"^[^\\/]*")
        vn, rem = self.tab._find(path)
        if not self.trusted:
            # fallback tree: record the label for this exact path only
            if "/" in rem:
                # label the first segment below the matched node as unknown
                zs = os.path.join(vn.vpath, rem.split("/")[0])
                self.tab.add("idk", zs, zs)
            if rem:
                self.tab.add(nval, path, path)
            else:
                vn.realpath = nval

            return

        # trusted tree: replace the leading segment (the fs type) of the
        # label on the matched node and on every node below it
        visit = [vn]
        while visit:
            vn = visit.pop()
            vn.realpath = ptn.sub(nval, vn.realpath)
            visit.extend(list(vn.nodes.values()))

    def get_unix(self, path):
        """
        Return `(fstype, mountpoint)` for `path` on non-windows systems.

        Builds the mount tree if there is none; if that fails, a warning
        is logged (once) and the untrusted fallback tree is used, which
        only yields a label for paths that exactly match a node and
        otherwise returns `("idk", "")`.
        """
        if not self.tab:
            try:
                self.build_tab()
                self.trusted = True
            except Exception:
                if not self.warned:
                    self.warned = True
                    t = "failed to associate fs-mounts with the VFS (this is fine):\n%s"
                    self.log(t % (min_ex(),), 6)
                self.build_fallback()

        ret = self.tab._find(path)[0]
        if self.trusted or path == ret.vpath:
            # only the text before the first slash is the fs type
            return ret.realpath.split("/")[0], ret.vpath
        else:
            return "idk", ""

    def get_w32(self, path):
        """
        Return `(label, "")` for a volume identifier from `_winpath`.

        Uses the fallback tree, so the label is "idk" unless it was
        changed through `relabel`.
        """
        if not self.tab:
            self.build_fallback()

        ret = self.tab._find(path)[0]
        return ret.realpath, ""


# shared Fstab instance; created by ramdisk_chk on first use
_fstab = None

# filesystem types which ramdisk_chk reports as "win"
winfs = {"msdos", "vfat", "ntfs", "exfat"}



def ramdisk_chk(asrv):
    """
    Revoke write access on volumes stored in RAM, then settle each
    volume's "fsnt" flag.

    Pass 1: every volume with write-access (and without the "wram" flag)
    whose directory is on tmpfs, or on an overlay mounted at "/", has its
    uwrite/umove sets cleared and entries 1 and 2 of each per-user tuple
    in `uaxs` set to False (presumably the write and move bits -- confirm
    against AuthSrv). One warning listing all affected volumes is logged.

    Pass 2: "fsnt" is normalized to "lin", "win" or "mac"; when it is not
    one of those already, it is derived from the detected filesystem type
    and also written into the volume's `js_ls` and `js_htm`.
    """
    global _fstab

    logger = asrv.log_func or print
    if not _fstab:
        _fstab = Fstab(logger, asrv.args, False)

    volatile = ("tmpfs", "overlay")
    revoked = []
    for node in asrv.vfs.all_nodes.values():
        if not node.axs.uwrite or "wram" in node.flags:
            continue

        abspath = node.realpath
        if not abspath or os.path.isfile(abspath):
            continue

        fstype, mountpoint = _fstab.get(abspath)
        mountpoint = "/" + mountpoint.strip("/")
        in_ram = fstype == "tmpfs" or (mountpoint == "/" and fstype in volatile)
        if not in_ram:
            continue

        revoked.append((node.vpath, abspath, fstype, mountpoint))
        node.axs.uwrite.clear()
        node.axs.umove.clear()
        for uname, perms in list(node.uaxs.items()):
            bits = list(perms)
            bits[1] = False
            bits[2] = False
            node.uaxs[uname] = tuple(bits)

    if revoked:
        t = "WARNING: write-access was removed from the following volumes because they are not mapped to an actual HDD for storage! All uploaded data would live in RAM only, and all uploaded files would be LOST on next reboot. To allow uploading and ignore this hazard, enable the 'wram' option (global/volflag). List of affected volumes:"
        rows = "".join(
            "\n  volume=[/%s], abspath=%r, type=%s, root=%r" % entry
            for entry in revoked
        )
        logger("vfs", t + rows + "\n", 1)

    fallback = "mac" if MACOS else "lin"
    for vol in asrv.vfs.all_nodes.values():
        if not vol.realpath or vol.flags.get("is_file"):
            continue

        hint = vol.flags["fsnt"].strip()[:3].lower()
        if ANYWIN and not hint:
            hint = "win"

        if hint in ("lin", "win", "mac"):
            # explicitly configured (or implied by the host OS); keep it
            vol.flags["fsnt"] = hint
            continue

        # no usable hint; decide based on the filesystem type on disk
        fstype = _fstab.get(vol.realpath)[0]
        family = "win" if fstype in winfs else fallback
        htm = json.loads(vol.js_htm)
        vol.flags["fsnt"] = vol.js_ls["fsnt"] = htm["fsnt"] = family
        vol.js_htm = json_hesc(json.dumps(htm))
