"""Utility methods for gruut"""
import itertools
import logging
import os
import re
import ssl
import typing
import xml.etree.ElementTree as etree
from pathlib import Path
from urllib.request import urlopen

import networkx as nx
from gruut_ipa import IPA

from gruut.const import (
    DATA_PROP,
    LANG_ALIASES,
    NODE_TYPE,
    EndElement,
    GraphType,
    InlineLexicon,
    Lexeme,
    Node,
    WordRole,
)
from gruut.resources import _DIR

_LOGGER = logging.getLogger("gruut.utils")

# -----------------------------------------------------------------------------
# Language utilities
# -----------------------------------------------------------------------------

LANG_SPLIT_PATTERN = re.compile(r"[-_]")


def resolve_lang(lang: str) -> str:
    """
    Try to resolve language using aliases.

    Args:
        lang: Language name or alias

    Returns:
        Resolved language name
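
    Example (illustrative; the exact mapping depends on ``LANG_ALIASES``)::

        resolve_lang("en_US")  # -> "en-us" (assuming "en-us" is not itself aliased)
        resolve_lang("fr")     # -> "fr-fr" if "fr" is a registered alias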
    """
    lang = lang.lower().replace("_", "-")

    return LANG_ALIASES.get(lang, lang)


def find_lang_dir(
    lang: str,
    search_dirs: typing.Optional[typing.Iterable[typing.Union[str, Path]]] = None,
) -> typing.Optional[Path]:
    """
    Search for a language's model directory by name.

    Tries to find a directory by:

    #. Importing a module named ``gruut_lang_<short_lang>``, where short_lang is "en" for "en-us", etc.
    #. Looking for ``<lang>/lexicon.db`` in each directory in order:

       * ``search_dirs``
       * ``$XDG_CONFIG_HOME/gruut``
       * A "data" directory next to the gruut module

    Args:
        lang: Full language name (e.g., en-us)
        search_dirs: Optional iterable of directory paths to search first

    Returns:
        Path to the language model directory or None if it can't be found
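
    Example (illustrative; the result depends on which ``gruut_lang_*`` packages
    and data directories are actually installed)::

        lang_dir = find_lang_dir("en-us", search_dirs=["/usr/share/gruut"])
        if lang_dir is not None:
            lexicon_path = lang_dir / "lexicon.db"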
    """
    base_lang = LANG_SPLIT_PATTERN.split(lang)[0].lower()
    lang_module_name = f"gruut_lang_{base_lang}"

    try:
        lang_module = __import__(lang_module_name)

        _LOGGER.debug("(%s) successfully imported %s", lang, lang_module_name)

        return lang_module.get_lang_dir()
    except ImportError:
        _LOGGER.debug("(%s) couldn't import module %s", lang, lang_module_name)

    search_dirs = typing.cast(typing.List[Path], [Path(p) for p in search_dirs or []])

    # ${XDG_CONFIG_HOME}/gruut or ${HOME}/.config/gruut
    maybe_config_home = os.environ.get("XDG_CONFIG_HOME")
    if maybe_config_home:
        search_dirs.append(Path(maybe_config_home) / "gruut")
    else:
        search_dirs.append(Path.home() / ".config" / "gruut")

    # Data directory *next to* gruut
    search_dirs.append(_DIR.parent / "data")

    _LOGGER.debug("(%s) searching %s for language file(s)", lang, search_dirs)

    for check_dir in search_dirs:
        lang_dir = check_dir / lang
        lexicon_path = lang_dir / "lexicon.db"
        if lexicon_path.is_file():
            _LOGGER.debug("(%s) found language file(s) in %s", lang, lang_dir)
            return lang_dir

    return None


# -----------------------------------------------------------------------------
# Babel
# -----------------------------------------------------------------------------


def get_currency_names(locale_str: str) -> typing.Dict[str, str]:
    """
    Try to get currency names and symbols for a Babel locale.

    Args:
        locale_str: Babel locale name (e.g., "en_US")

    Returns:
        Dictionary whose keys are currency symbols (like "$") and whose values are currency names (like "USD")
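
    Example (illustrative; requires the optional ``babel`` package, otherwise an
    empty dictionary is returned)::

        names = get_currency_names("en_US")
        names.get("$")  # likely "USD", depending on the installed Babel data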
    """
    currency_names = {}

    try:
        import babel
        import babel.numbers

        locale = babel.Locale(locale_str)
        currency_names = {
            babel.numbers.get_currency_symbol(cn): cn for cn in locale.currency_symbols
        }
    except ImportError:
        # Expected if babel is not installed
        pass
    except Exception:
        _LOGGER.warning(
            "Failed to get currency names for locale %s", locale_str, exc_info=True
        )

    return currency_names


# -----------------------------------------------------------------------------
# Iteration
# -----------------------------------------------------------------------------


def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)


def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)


def sliding_window(iterable, n=2):
    """Returns a sliding window of size n over an iterable"""
    iterables = itertools.tee(iterable, n)

    for win_iter, num_skipped in zip(iterables, itertools.count()):
        for _ in range(num_skipped):
            next(win_iter, None)

    return zip(*iterables)
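
# Illustrative behavior of sliding_window:
#
#     list(sliding_window("ABCD", n=2))  # -> [("A", "B"), ("B", "C"), ("C", "D")]
#     list(sliding_window("ABCD", n=3))  # -> [("A", "B", "C"), ("B", "C", "D")]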


# -----------------------------------------------------------------------------
# XML
# -----------------------------------------------------------------------------

NO_NAMESPACE_PATTERN = re.compile(r"^{[^}]+}")


def tag_no_namespace(tag: str) -> str:
    """Remove namespace from XML tag"""
    return NO_NAMESPACE_PATTERN.sub("", tag)
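
# For example (illustrative), stripping an SSML namespace from a tag:
#
#     tag_no_namespace("{http://www.w3.org/2001/10/synthesis}speak")  # -> "speak"
#     tag_no_namespace("speak")  # unchanged -> "speak"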


def attrib_no_namespace(
    element: etree.Element, name: str, default: typing.Any = None
) -> typing.Any:
    """Search for an attribute by key without namespaces"""
    for key, value in element.attrib.items():
        key_no_ns = NO_NAMESPACE_PATTERN.sub("", key)
        if key_no_ns == name:
            return value

    return default
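
# Illustrative lookup; ``phoneme_elem`` is a hypothetical parsed element whose
# "alphabet" attribute may or may not carry an XML namespace prefix:
#
#     attrib_no_namespace(phoneme_elem, "alphabet", default="ipa")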


def text_and_elements(element, is_last=False):
    """Yields element, text, sub-elements, end element, and tail"""
    element_metadata = None

    if is_last:
        # True if this is the last child element of a parent.
        # Used to preserve whitespace.
        element_metadata = {"is_last": True}

    yield element, element_metadata

    # Text before any tags (or end tag)
    text = element.text if element.text is not None else ""
    if text.strip():
        yield text

    children = list(element)
    last_child_idx = len(children) - 1

    for child_idx, child in enumerate(children):
        # Sub-elements
        is_last = child_idx == last_child_idx
        yield from text_and_elements(child, is_last=is_last)

    # End of current element
    yield EndElement(element)

    # Text after the current tag
    tail = element.tail if element.tail is not None else ""
    if tail.strip():
        yield tail
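
# Illustrative traversal for a small hand-built fragment (in gruut the element
# normally comes from parsed SSML):
#
#     root = etree.fromstring("<s>Hello <w>world</w></s>")
#     list(text_and_elements(root))
#     # -> [(<s>, None), "Hello ", (<w>, {"is_last": True}), "world",
#     #     EndElement(<w>), EndElement(<s>)]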


def load_lexicon(
    uri: str,
    lexicon: InlineLexicon,
    ssl_context: typing.Optional[ssl.SSLContext] = None,
):
    """Loads a pronunciation lexicon from a URI"""
    if ssl_context is None:
        ssl_context = ssl.create_default_context()

    with urlopen(uri, context=ssl_context) as response:
        tree = etree.parse(response)
        for lexeme_elem in tree.getroot():
            if tag_no_namespace(lexeme_elem.tag) != "lexeme":
                continue

            lexeme = Lexeme()

            role_str = attrib_no_namespace(lexeme_elem, "role")
            if role_str:
                lexeme.roles = set(role_str.strip().split())

            for lexeme_child in lexeme_elem:

                child_tag = tag_no_namespace(lexeme_child.tag)
                if child_tag == "grapheme":
                    if lexeme_child.text:
                        lexeme.grapheme = lexeme_child.text.strip()
                elif child_tag == "phoneme":
                    if lexeme_child.text:
                        lexeme.phonemes = maybe_split_ipa(lexeme_child.text.strip())

            if lexeme.grapheme and lexeme.phonemes:
                role_phonemes = lexicon.words.get(lexeme.grapheme)
                if role_phonemes is None:
                    role_phonemes = {}
                    lexicon.words[lexeme.grapheme] = role_phonemes

                assert role_phonemes is not None

                roles = lexeme.roles or [WordRole.DEFAULT]
                for role in roles:
                    role_phonemes[role] = lexeme.phonemes
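
# Illustrative call; the URI is hypothetical and ``my_lexicon`` is an
# already-constructed InlineLexicon. The document is expected to contain
# PLS-style <lexeme> entries with <grapheme> and <phoneme> children:
#
#     load_lexicon("https://example.org/lexicon.pls", my_lexicon)
#     my_lexicon.words["tomato"]  # e.g. {WordRole.DEFAULT: ["t", "ə", "m", "eɪ", "t", "oʊ"]}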


# -----------------------------------------------------------------------------
# Text
# -----------------------------------------------------------------------------

NON_WORDS_PATTERN = re.compile(r"\W")


def remove_non_word_chars(s: str) -> str:
    """Removes non-word characters from a string"""
    return NON_WORDS_PATTERN.sub("", s)
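
# For example: remove_non_word_chars("can't!") -> "cant" (anything matching the
# regex \W, such as apostrophes and punctuation, is dropped).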


def maybe_split_ipa(s: str) -> typing.List[str]:
    """Split on whitespace if a space is present, otherwise return string as list of graphemes"""
    if " " in s:
        # Manual separation
        return s.split()

    # Automatic separation
    return IPA.graphemes(s)
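
# Illustrative splits (the second form relies on gruut_ipa's grapheme clustering):
#
#     maybe_split_ipa("h ɛ l oʊ")  # -> ["h", "ɛ", "l", "oʊ"] (manual separation)
#     maybe_split_ipa("hɛloʊ")     # -> ["h", "ɛ", "l", "oʊ"] if IPA.graphemes
#                                  #    keeps the diphthong as one grapheme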


# -----------------------------------------------------------------------------
# Graph
# -----------------------------------------------------------------------------


def print_graph(
    graph: GraphType,
    node: typing.Union[NODE_TYPE, Node],
    indent: str = "--",
    level: int = 1,
    print_func=print,
):
    """Prints a graph to the console"""
    if isinstance(node, Node):
        n_data = node
        graph_node = node.node
    else:
        graph_node = node
        n_data = typing.cast(Node, graph.nodes[graph_node][DATA_PROP])

    print_func(indent * level, graph_node, n_data)
    for succ_node in graph.successors(graph_node):
        print_graph(
            graph, succ_node, indent=indent, level=level + 1, print_func=print_func
        )


def leaves(graph: GraphType, node: Node):
    """Iterate through the leaves of a graph in depth-first order"""
    for dfs_node in nx.dfs_preorder_nodes(graph, node.node):
        if graph.out_degree(dfs_node) != 0:
            continue

        yield graph.nodes[dfs_node][DATA_PROP]


def pipeline_split(split_func, graph: GraphType, parent_node: Node) -> bool:
    """Splits leaf nodes of the tree into zero or more sub-nodes"""
    was_changed = False

    for leaf_node in list(leaves(graph, parent_node)):
        for node_class, node_kwargs in split_func(graph, leaf_node):
            new_node = node_class(node=len(graph), **node_kwargs)
            graph.add_node(new_node.node, data=new_node)
            graph.add_edge(leaf_node.node, new_node.node)
            was_changed = True

    return was_changed
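
# Illustrative split function (hypothetical); split_func receives the graph and a
# leaf node and yields (node_class, kwargs) pairs for the new child nodes. The
# WordNode class and ``text`` field are assumptions about the node types in use:
#
#     def split_hyphens(graph, node):
#         for part in node.text.split("-"):
#             yield WordNode, {"text": part}
#
#     pipeline_split(split_hyphens, graph, sentence_node)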


def pipeline_transform(transform_func, graph: GraphType, parent_node: Node) -> bool:
    """Transforms leaves of the tree with a custom function"""
    was_changed = False

    for leaf_node in list(leaves(graph, parent_node)):
        if transform_func(graph, leaf_node):
            was_changed = True

    return was_changed
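
# Illustrative transform function (hypothetical); transform_func mutates a leaf
# node in place and returns True if anything changed:
#
#     def lowercase_words(graph, node):
#         text = getattr(node, "text", None)
#         if text and not text.islower():
#             node.text = text.lower()
#             return True
#         return False
#
#     pipeline_transform(lowercase_words, graph, sentence_node)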
