import json
import os
import time

from wandb_gql import gql

import wandb
from wandb import util
from wandb.apis.internal import Api
from wandb.sdk import lib as wandb_lib
from wandb.sdk.data_types.utils import val_to_json

# Name of the HDF5 file, inside the run directory, that holds summary values
# too large to keep inline in the JSON summary.
DEEP_SUMMARY_FNAME = "wandb.h5"
# "_type" markers of JSON summary entries whose actual data lives in the h5 file.
H5_TYPES = ("numpy.ndarray", "tensorflow.Tensor", "torch.Tensor")
# Optional dependencies. The code below treats a falsy `h5py` as "h5py is not
# available" -- presumably `util.get_module` returns a falsy value when the
# import fails (TODO confirm against wandb.util).
h5py = util.get_module("h5py")
np = util.get_module("numpy")


class SummarySubDict:
    """Nested dict-like object that proxies read and write operations through a root object.

    This lets us do synchronous serialization and lazy loading of large values.

    Each node keeps two views of its level of the tree:

    * ``_json_dict`` -- the JSON-encoded values stored in the root's tree. It
      has the full key set, including entries that were never decoded.
    * ``_dict`` -- decoded values, filled lazily by ``get()`` through the
      root's ``_root_get()``.

    String keys are normalized with ``.strip()``. Attribute access on names
    that do not start with an underscore is forwarded to item access.
    """

    def __init__(self, root=None, path=()):
        """Create a node.

        Args:
            root: The root summary object. ``None`` makes this node its own root.
            path: Sequence of keys leading from the root to this node.
        """
        self._path = tuple(path)
        if root is None:
            self._root = self
            self._json_dict = {}
        else:
            self._root = root
            json_dict = root._json_dict
            for k in path:
                json_dict = json_dict.get(k, {})

            self._json_dict = json_dict
        self._dict = {}

        # We use this to track which keys the user has set explicitly
        # so that we don't automatically overwrite them when we update
        # the summary from the history.
        self._locked_keys = set()

    def __setattr__(self, k, v):
        # Underscore-prefixed names are real attributes; everything else is
        # stored as a summary item.
        k = k.strip()
        if k.startswith("_"):
            object.__setattr__(self, k, v)
        else:
            self[k] = v

    def __getattr__(self, k):
        # NOTE: a missing summary key surfaces as KeyError (raised by
        # __getitem__), not AttributeError.
        k = k.strip()
        if k.startswith("_"):
            return object.__getattribute__(self, k)
        else:
            return self[k]

    def _root_get(self, path, child_dict):
        """Load a value at a particular path from the root.

        This should only be implemented by the "_root" child class.

        We pass the child_dict so the item can be set on it or not as
        appropriate. Returning None for a nonexistent path wouldn't be
        distinguishable from that path being set to the value None.
        """
        raise NotImplementedError

    def _root_set(self, path, new_keys_values):
        """Set a value at a particular path in the root.

        This should only be implemented by the "_root" child class.
        """
        raise NotImplementedError

    def _root_del(self, path):
        """Delete a value at a particular path in the root.

        This should only be implemented by the "_root" child class.
        """
        raise NotImplementedError

    def _write(self, commit=False):
        # should only be implemented on the root summary
        raise NotImplementedError

    def _relink_json_dicts(self):
        """Re-bind ``_json_dict`` of this node and its cached sub-dicts.

        The root replaces dict values with freshly encoded dicts when they are
        stored, so a sub-dict created before the store would otherwise keep
        pointing at a detached dict and report no keys.
        """
        if self._root is not self:
            json_dict = self._root._json_dict
            for key in self._path:
                json_dict = json_dict.get(key)
                if not isinstance(json_dict, dict):
                    # Nothing dict-shaped is stored at this path; keep the
                    # current binding.
                    return
            self._json_dict = json_dict

        for value in self._dict.values():
            if isinstance(value, SummarySubDict):
                value._relink_json_dicts()

    def keys(self):
        """Return the keys stored at this level."""
        # _json_dict has the full set of keys, including those for h5 objects
        # that may not have been loaded yet
        return self._json_dict.keys()

    def get(self, k, default=None):
        """Return the decoded value for ``k``, or ``default`` if it is absent."""
        if isinstance(k, str):
            k = k.strip()
        if k not in self._dict:
            # Ask the root to decode the value into our cache if it exists.
            self._root._root_get(self._path + (k,), self._dict)
        return self._dict.get(k, default)

    def items(self):
        """Yield ``(key, decoded value)`` pairs for every stored key."""
        # not all items may be loaded into self._dict, so we
        # have to build the sequence of items from scratch
        for k in self.keys():
            yield k, self[k]

    def __getitem__(self, k):
        if isinstance(k, str):
            k = k.strip()

        self.get(k)  # load the value into _dict if it should be there
        res = self._dict[k]

        return res

    def __contains__(self, k):
        if isinstance(k, str):
            k = k.strip()

        return k in self._json_dict

    def __setitem__(self, k, v):
        """Store ``v`` under ``k``, lock the key and write the summary."""
        if isinstance(k, str):
            k = k.strip()

        path = self._path

        if isinstance(v, dict):
            # Create the JSON slot first so the new sub-dict binds to the dict
            # that actually lives in the root's tree.
            self._root._root_set(path, [(k, {})])
            self._dict[k] = SummarySubDict(self._root, path + (k,))
            self._dict[k].update(v)
        else:
            self._dict[k] = v
            self._root._root_set(path, [(k, v)])

        self._locked_keys.add(k)

        self._root._write()

        return v

    def __delitem__(self, k):
        """Remove ``k`` from the summary and write the summary.

        Raises:
            KeyError: If the root has no value stored for ``k``.
        """
        if isinstance(k, str):
            k = k.strip()

        # The value may never have been decoded into the cache (lazy loading),
        # so its absence from _dict is not an error.
        self._dict.pop(k, None)
        self._root._root_del(self._path + (k,))

        self._root._write()

    def __repr__(self):
        # use a copy of _dict, except add placeholders for h5 objects, etc.
        # that haven't been loaded yet
        repr_dict = dict(self._dict)
        for k in self._json_dict:
            v = self._json_dict[k]
            if (
                k not in repr_dict
                and isinstance(v, dict)
                and v.get("_type") in H5_TYPES
            ):
                # unloaded h5 objects may be very large. use a placeholder for them
                # if we haven't already loaded them
                repr_dict[k] = "..."
            else:
                repr_dict[k] = self[k]

        return repr(repr_dict)

    def update(self, key_vals=None, overwrite=True):
        """Store several values at once and commit the summary.

        Args:
            key_vals: Mapping of keys to values. Nothing is stored when it is
                empty or ``None``, but the summary is still committed.
            overwrite: When true, every key is written and added to the
                "locked" set. When false, keys that are already locked are
                skipped and the written keys are not locked.
        """
        if key_vals:
            write_items = self._update(key_vals, overwrite)
            self._root._root_set(self._path, write_items)
            # Sub-dicts built by _update() were created before the root stored
            # their values; point them at the stored dicts.
            self._relink_json_dicts()
        self._root._write(commit=True)

    def _update(self, key_vals, overwrite):
        """Update the decoded cache and return the ``(key, value)`` pairs to store.

        The caller is responsible for passing the returned pairs to the
        root's ``_root_set()``.
        """
        if not key_vals:
            return []
        key_vals = {k.strip(): v for k, v in key_vals.items()}
        if overwrite:
            write_items = list(key_vals.items())
            self._locked_keys.update(key_vals.keys())
        else:
            write_keys = set(key_vals.keys()) - self._locked_keys
            write_items = [(k, key_vals[k]) for k in write_keys]

        for key, value in write_items:
            if isinstance(value, dict):
                self._dict[key] = SummarySubDict(self._root, self._path + (key,))
                self._dict[key]._update(value, overwrite)
            else:
                self._dict[key] = value

        return write_items


class Summary(SummarySubDict):
    """Store summary metrics (eg. accuracy) during and after a run.

    You can manipulate this as if it's a Python dictionary but the keys
    get mangled. .strip() is called on them, so spaces at the beginning
    and end are removed.

    This is the root of the summary tree: it owns the JSON representation
    (``_json_dict``) and the optional h5 file used for large values.
    """

    def __init__(self, run, summary=None):
        """Create the root summary.

        Args:
            run: Run object; ``run.dir`` is used to locate the h5 file.
            summary: Optional existing JSON summary dict to start from. It is
                used as-is (not copied).
        """
        super().__init__()
        self._run = run
        self._h5_path = os.path.join(self._run.dir, DEEP_SUMMARY_FNAME)
        # Lazy load the h5 file
        self._h5 = None

        # Mirrored version of self._dict with versions of values that get written
        # to JSON kept up to date by self._root_set() and self._root_del().
        self._json_dict = summary if summary is not None else {}

    @staticmethod
    def _h5_key(path):
        """Return the h5 dataset name for a tuple of summary keys."""
        return "summary/" + ".".join(path)

    def _json_get(self, path):
        # Placeholder: does nothing and returns None.
        pass

    def _root_get(self, path, child_dict):
        """Decode the value at ``path`` into ``child_dict`` if it exists."""
        json_dict = self._json_dict
        for key in path[:-1]:
            json_dict = json_dict[key]

        key = path[-1]
        if key in json_dict:
            child_dict[key] = self._decode(path, json_dict[key])

    def _root_del(self, path):
        """Delete the JSON entry at ``path`` and its h5 dataset, if any.

        Raises:
            KeyError: If no JSON entry exists at ``path``.
        """
        json_dict = self._json_dict
        for key in path[:-1]:
            json_dict = json_dict[key]

        val = json_dict.pop(path[-1])
        if not (isinstance(val, dict) and val.get("_type") in H5_TYPES):
            return

        if not h5py:
            wandb.termerror("Deleting tensors in summary requires h5py")
            return

        self.open_h5()
        try:
            del self._h5[self._h5_key(path)]
        except KeyError:
            # The dataset is already missing from the h5 file. The JSON entry
            # is gone, so there is nothing left to remove (write_h5 tolerates
            # a missing dataset the same way).
            pass
        else:
            self._h5.flush()

    def _root_set(self, path, new_keys_values):
        """Encode and store ``(key, value)`` pairs under ``path``."""
        json_dict = self._json_dict
        for key in path:
            json_dict = json_dict[key]

        for new_key, new_value in new_keys_values:
            json_dict[new_key] = self._encode(new_value, path + (new_key,))

    def write_h5(self, path, val):
        """Store ``val`` in the h5 file under ``path``, replacing any old value."""
        # ensure the file is open
        self.open_h5()

        if not self._h5:
            wandb.termerror("Storing tensors in summary requires h5py")
            return

        h5_key = self._h5_key(path)
        try:
            del self._h5[h5_key]
        except KeyError:
            pass
        self._h5[h5_key] = val
        self._h5.flush()

    def read_h5(self, path, val=None):
        """Return the h5 entry stored under ``path``, or ``val`` if it is absent.

        Returns ``None`` (after reporting an error) when the h5 file could not
        be opened.
        """
        # ensure the file is open
        self.open_h5()

        if not self._h5:
            wandb.termerror("Reading tensors from summary requires h5py")
            return None

        return self._h5.get(self._h5_key(path), val)

    def open_h5(self):
        """Open the h5 file if it is not open yet and h5py is available."""
        if not self._h5 and h5py:
            self._h5 = h5py.File(self._h5_path, "a", libver="latest")

    def _decode(self, path, json_value):
        """Decode a `dict` encoded by `Summary._encode()`, loading h5 objects.

        h5 objects may be very large, so we won't have loaded them automatically.
        """
        if not isinstance(json_value, dict):
            return json_value

        value_type = json_value.get("_type")
        if value_type in H5_TYPES:
            return self.read_h5(path, json_value)
        if value_type == "data-frame":
            wandb.termerror(
                "This data frame was saved via the wandb data API. Contact support@wandb.com for help."
            )
            return None
        # TODO: transform wandb objects and plots
        return SummarySubDict(self, path)

    def _encode(self, value, path_from_root):
        """Normalize, compress, and encode sub-objects for backend storage.

        value: Object to encode.
        path_from_root: `tuple` of key strings from the top-level summary to the
            current `value`.

        Returns:
            A new tree of dict's with large objects replaced with dictionaries
            with "_type" entries that say which type the original data was.
        """

        # Constructs a new `dict` tree that discards and/or encodes objects
        # that aren't JSON serializable.

        if isinstance(value, dict):
            return {
                key: self._encode(child, path_from_root + (key,))
                for key, child in value.items()
            }

        path = ".".join(path_from_root)
        friendly_value, _converted = util.json_friendly(
            val_to_json(self._run, path, value, namespace="summary")
        )
        json_value, compressed = util.maybe_compress_summary(
            friendly_value, util.get_h5_typename(value)
        )
        if compressed:
            # The JSON entry is only a stand-in; the data goes to the h5 file.
            self.write_h5(path_from_root, friendly_value)

        return json_value


def download_h5(run_id, entity=None, project=None, out_dir=None):
    """Download the h5 summary file of a run.

    Args:
        run_id: Id of the run whose summary file is wanted.
        entity: Entity name; falls back to the API's "entity" setting.
        project: Project name; falls back to the API's "project" setting.
        out_dir: Passed through to ``api.download_write_file``.

    Returns:
        The path reported by ``api.download_write_file``, or ``None`` when the
        file metadata is missing or carries no md5.
    """
    api = Api()
    project_name = project or api.settings("project")
    entity_name = entity or api.settings("entity")
    meta = api.download_url(
        project_name, DEEP_SUMMARY_FNAME, entity=entity_name, run=run_id
    )

    has_file = bool(meta) and "md5" in meta and meta["md5"] is not None
    if not has_file:
        return None

    # TODO: make this non-blocking
    wandb.termlog("Downloading summary data...")
    path, _response = api.download_write_file(meta, out_dir=out_dir)
    return path


def upload_h5(file, run_id, entity=None, project=None):
    """Push the h5 summary file at ``file`` to the given run.

    The file is uploaded under its base name.
    """
    api = Api()
    wandb.termlog("Uploading summary data...")
    with open(file, "rb") as handle:
        files = {os.path.basename(file): handle}
        api.push(files, run=run_id, project=project, entity=entity)


class FileSummary(Summary):
    """Summary persisted to a JSON file in the run directory."""

    def __init__(self, run):
        super().__init__(run)
        self._fname = os.path.join(run.dir, wandb_lib.filenames.SUMMARY_FNAME)
        self.load()

    def load(self):
        """Load the JSON summary from disk.

        An unreadable or unparsable file results in an empty summary.
        """
        try:
            with open(self._fname) as f:
                self._json_dict = json.load(f)
        except (OSError, ValueError):
            self._json_dict = {}

    def _write(self, commit=False):
        """Write the JSON summary to disk and close the h5 file if it is open."""
        # TODO: we just ignore commit to ensure backward capability

        # Serialize before opening the file: mode "w" truncates it, so a
        # serialization error must not be able to wipe the existing summary.
        serialized = util.json_dumps_safer(self._json_dict)
        with open(self._fname, "w") as f:
            f.write(serialized)
            f.write("\n")
            f.flush()
            os.fsync(f.fileno())
        if self._h5:
            self._h5.close()
            self._h5 = None


class HTTPSummary(Summary):
    """Summary that is stored on the backend through a GraphQL client."""

    def __init__(self, run, client, summary=None):
        """Create the summary.

        Args:
            run: Run object (``id``, ``entity``, ``project``, ``dir``,
                ``storage_id`` and ``path`` are read from it).
            client: Client whose ``execute`` method runs the GraphQL mutation.
            summary: Optional existing JSON summary dict to start from.
        """
        super().__init__(run, summary=summary)
        self._run = run
        self._client = client
        # Used to decide whether the local h5 file changed since this object
        # was created and therefore needs to be uploaded.
        self._started = time.time()

    def __delitem__(self, key):
        """Remove a top-level key locally; nothing is sent to the backend here.

        Raises:
            KeyError: If ``key`` is not in the summary.
        """
        if key not in self._json_dict:
            raise KeyError(key)
        del self._json_dict[key]
        # Drop the decoded copy as well, otherwise lookups would keep
        # returning the deleted value from the cache.
        self._dict.pop(key, None)

    def load(self):
        pass

    def open_h5(self):
        """Download the run's h5 file (if not open yet), then open it."""
        if not self._h5 and h5py:
            download_h5(
                self._run.id,
                entity=self._run.entity,
                project=self._run.project,
                out_dir=self._run.dir,
            )
        super().open_h5()

    def _write(self, commit=False):
        """Send the summary to the backend when ``commit`` is true.

        Returns ``False`` without doing anything when ``commit`` is false.
        """
        if not commit:
            return False

        mutation = gql(
            """
        mutation UpsertBucket( $id: String, $summaryMetrics: JSONString) {
            upsertBucket(input: { id: $id, summaryMetrics: $summaryMetrics}) {
                bucket { id }
            }
        }
        """
        )
        if self._h5:
            self._h5.close()
            self._h5 = None
        res = self._client.execute(
            mutation,
            variable_values={
                "id": self._run.storage_id,
                "summaryMetrics": util.json_dumps_safer(self._json_dict),
            },
        )
        # NOTE: this check disappears when Python runs with -O.
        assert res["upsertBucket"]["bucket"]["id"]
        entity, project, run = self._run.path
        # Only upload the h5 file if it was modified after this object was created.
        if (
            os.path.exists(self._h5_path)
            and os.path.getmtime(self._h5_path) >= self._started
        ):
            upload_h5(self._h5_path, run, entity=entity, project=project)
