#!/usr/bin/env python3
"""
Grapheme-to-phoneme prediction using python-crfsuite.

Training requires a pre-aligned corpus in Phonetisaurus format.
https://github.com/AdolfVonKleist/Phonetisaurus

The format of this corpus is:

    t}t e}ˈɛ s}s t}t

Each line contains a single word, with graphemes and phonemes separated by "}".
Multiple graphemes are separated by "|":

    s|h}ʃ o|w}ˈoʊ

The empty phoneme is "_":

    w}w h}_ y}ˈaɪ

Example:

.. code-block:: sh

    python3 -m gruut.g2p train --corpus g2p.corpus --output model.crf

Pre-trained models have the following settings:

* c1 = 0
* c2 = 1
* max-iterations = 100
"""
import argparse
import base64
import itertools
import logging
import os
import sys
import time
import typing
import unicodedata
from pathlib import Path

import pycrfsuite

# Module-level logger used by the training and CLI functions
_LOGGER = logging.getLogger("gruut.g2p")

# -----------------------------------------------------------------------------

# Feature dict for a single grapheme (feature name -> value)
FEATURES_TYPE = typing.Dict[str, typing.Union[str, bool, int, float]]
# Default label for the empty phoneme (dropped from predictions)
EPS_PHONEME = "_"
# Default string used to join/split multiple phonemes within one label
PHONEME_JOIN = "|"


class GraphemesToPhonemes:
    """Predicts phonemes for a word using a CRF tagger over its graphemes"""

    def __init__(
        self,
        crf_tagger: typing.Union[str, Path, pycrfsuite.Tagger],
        eps_phoneme: str = EPS_PHONEME,
        phoneme_join: str = PHONEME_JOIN,
    ):
        """
        Wrap an existing tagger or load one from a model path.

        crf_tagger: a ready pycrfsuite.Tagger, or the path of a model to open
        eps_phoneme: label of the empty phoneme (omitted from predictions)
        phoneme_join: separator between phonemes packed into a single label
        """
        if not isinstance(crf_tagger, pycrfsuite.Tagger):
            # Anything that is not a tagger is treated as a model path
            model_path = str(crf_tagger)
            crf_tagger = pycrfsuite.Tagger()
            crf_tagger.open(model_path)

        self.crf_tagger = crf_tagger
        self.eps_phoneme = eps_phoneme
        self.phoneme_join = phoneme_join

    def __call__(self, word: str, normalize: bool = True) -> typing.Sequence[str]:
        """Return the list of predicted phonemes for a word"""
        word_features = GraphemesToPhonemes.word2features(word, normalize=normalize)
        result: typing.List[str] = []

        for coded_label in self.crf_tagger.tag(word_features):
            label = GraphemesToPhonemes.decode_string(coded_label)

            # A label may hold several joined phonemes; empty phonemes are dropped
            result.extend(
                phoneme
                for phoneme in label.split(self.phoneme_join)
                if phoneme != self.eps_phoneme
            )

        return result

    # -------------------------------------------------------------------------

    @staticmethod
    def word2features(
        word: typing.Union[str, typing.List[str]], normalize: bool = True, **kwargs
    ):
        """
        Build one feature dict per grapheme of a word.

        Strings are NFC-normalized first when normalize is True.
        Extra keyword arguments are forwarded to grapheme2features.
        """
        if normalize and isinstance(word, str):
            # Compose combining characters into single code points
            # See: https://docs.python.org/3/library/unicodedata.html#unicodedata.normalize
            word = unicodedata.normalize("NFC", word)

        all_features = []
        for position in range(len(word)):
            all_features.append(
                GraphemesToPhonemes.grapheme2features(word, position, **kwargs)
            )

        return all_features

    @staticmethod
    def grapheme2features(
        word: typing.Union[str, typing.Sequence[str]],
        i: int,
        add_begin: bool = True,
        add_end: bool = True,
        chars_backward: int = 3,
        chars_forward: int = 3,
        bias: float = 1.0,
        encode: bool = True,
    ) -> FEATURES_TYPE:
        """
        Build the feature dict for the grapheme at index i.

        Features are the bias, the grapheme itself, up to chars_backward
        preceding and chars_forward following graphemes, and optional
        "begin"/"end" markers for the first/last position.
        Grapheme values are passed through encode_string when encode is True.
        """

        def _value(grapheme):
            """Return the feature value for a grapheme (encoded if requested)"""
            if encode:
                return GraphemesToPhonemes.encode_string(grapheme)

            return grapheme

        current = word[i]
        total = len(word)

        features: FEATURES_TYPE = {"bias": bias, "grapheme": _value(current)}

        if add_begin and (i == 0):
            features["begin"] = True

        # Left context: stop at the first offset that reaches past the start
        for offset in range(1, chars_backward + 1):
            if offset > i:
                break

            features[f"grapheme-{offset}"] = _value(word[i - offset])

        # Right context: stop at the first offset that reaches past the end
        for offset in range(1, chars_forward + 1):
            if (i + offset) >= total:
                break

            features[f"grapheme+{offset}"] = _value(word[i + offset])

        if add_end and (i == (total - 1)):
            features["end"] = True

        return features

    @staticmethod
    def encode_string(s: str) -> str:
        """Return the base64 (ASCII) form of a string; reversed by decode_string"""
        raw_bytes = s.encode()
        return base64.b64encode(raw_bytes).decode("ascii")

    @staticmethod
    def decode_string(s: str) -> str:
        """Return the original string from the output of encode_string"""
        raw_bytes = base64.b64decode(s.encode("ascii"))
        return raw_bytes.decode()


# -----------------------------------------------------------------------------


def train(
    corpus_path: typing.Union[str, Path],
    output_path: typing.Union[str, Path],
    group_separator: str = "}",
    item_separator: str = "|",
    phoneme_join: str = PHONEME_JOIN,
    eps_phoneme: str = EPS_PHONEME,
    remove_phonemes: typing.Optional[typing.Iterable[str]] = None,
    c1: float = 0.0,
    c2: float = 1.0,
    max_iterations: int = 100,
):
    """
    Train a new G2P model from a pre-aligned corpus.

    corpus_path: text file with one aligned word per line
    output_path: where the trained model is written (parent dirs are created)
    group_separator: separates graphemes from phonemes inside a token
    item_separator: separates multiple graphemes/phonemes inside a group
    phoneme_join: string used to join multiple phonemes into a single label
    eps_phoneme: label used for graphemes that have no phoneme
    remove_phonemes: phonemes dropped from every example
    c1: coefficient for L1 penalty
    c2: coefficient for L2 penalty
    max_iterations: maximum number of training iterations

    Lines containing a token without the group separator are skipped with a
    warning. Any error while adding an example is logged and re-raised.
    """
    corpus_path = Path(corpus_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    remove_phonemes = set(remove_phonemes or [])

    trainer = pycrfsuite.Trainer(verbose=False)

    with open(corpus_path, "r", encoding="utf-8") as corpus:
        for line_number, line in enumerate(corpus, start=1):
            line = line.strip()
            if not line:
                continue

            # Parse into graphemes and phonemes
            aligned_g: typing.List[str] = []
            aligned_p: typing.List[str] = []
            bad_part: typing.Optional[str] = None

            for part in line.split():
                # Graphemes/phonemes are separated by }
                gs_str, found_separator, ps_str = part.partition(group_separator)
                if not found_separator:
                    # Token cannot be split into graphemes/phonemes
                    bad_part = part
                    break

                # Multiple graphemes and phonemes are separated by |
                gs = gs_str.split(item_separator)
                ps = [
                    p for p in ps_str.split(item_separator) if p not in remove_phonemes
                ]

                # All phonemes of the group are attached to its first grapheme;
                # the remaining graphemes get the empty phoneme.
                for g_index, g in enumerate(gs):
                    aligned_g.append(g)

                    if (g_index == 0) and ps:
                        aligned_p.append(phoneme_join.join(ps))
                    else:
                        aligned_p.append(eps_phoneme)

            if bad_part is not None:
                _LOGGER.warning(
                    "Failed to align line %s: %s (no '%s' separator in '%s')",
                    line_number,
                    line,
                    group_separator,
                    bad_part,
                )
                continue

            # Add example to trainer
            try:
                encoded_p = [GraphemesToPhonemes.encode_string(p) for p in aligned_p]
                trainer.append(GraphemesToPhonemes.word2features(aligned_g), encoded_p)
            except Exception:
                _LOGGER.exception("graphemes=%s phonemes=%s", aligned_g, aligned_p)
                raise

    trainer.set_params(
        {
            "c1": c1,  # coefficient for L1 penalty
            "c2": c2,  # coefficient for L2 penalty
            "max_iterations": max_iterations,  # stop earlier
            # include transitions that are possible, but not observed
            "feature.possible_transitions": True,
        }
    )

    _LOGGER.debug(trainer.get_params())

    # Begin training
    _LOGGER.info("Training")

    start_time = time.perf_counter()
    trainer.train(str(output_path))
    end_time = time.perf_counter()

    _LOGGER.info("Training completed in %s second(s)", end_time - start_time)
    _LOGGER.info(trainer.logparser.last_iteration)


def do_train(args):
    """Run train() with the options parsed for the "train" sub-command"""
    train_options = {
        "corpus_path": args.corpus,
        "output_path": args.output,
        "group_separator": args.group_separator,
        "item_separator": args.item_separator,
        "remove_phonemes": args.remove_phonemes,
        "c1": args.c1,
        "c2": args.c2,
        "max_iterations": args.max_iterations,
    }

    train(**train_options)


def do_predict(args):
    """Print each input word followed by its predicted phonemes"""
    g2p = GraphemesToPhonemes(args.model)

    words = args.texts
    if not words:
        # No words on the command line: read them from stdin instead
        words = sys.stdin

        if os.isatty(sys.stdin.fileno()):
            print("Reading words from stdin...", file=sys.stderr)

    for raw_word in words:
        word = raw_word.strip()
        if word:
            print(word, *g2p(word))


def do_test(args):
    """
    CLI method for test.

    Reads lexicon lines ("<word> <phoneme> <phoneme> ...") from args.texts or
    stdin, predicts phonemes for every word with the model at args.model, and
    prints the error rate, error count and prediction speed.

    Raises ImportError if rapidfuzz is not installed and AssertionError if no
    phonemes could be compared.
    """
    try:
        # Newer rapidfuzz API
        from rapidfuzz.distance import Levenshtein

        levenshtein = Levenshtein.distance
    except ImportError:
        try:
            # Legacy rapidfuzz API
            from rapidfuzz.string_metric import levenshtein
        except ImportError:
            _LOGGER.critical("rapidfuzz library is needed for levenshtein distance")
            _LOGGER.critical("pip install 'rapidfuzz>=1.4.1'")
            raise

    tagger = GraphemesToPhonemes(args.model)

    # Load lexicon
    if args.texts:
        lines = args.texts
    else:
        lines = sys.stdin

        if os.isatty(sys.stdin.fileno()):
            print("Reading lexicon lines from stdin...", file=sys.stderr)

    lexicon = {}
    for line in lines:
        # Split on any whitespace so tab-separated lexicons are accepted too
        parts = line.strip().split(maxsplit=1)
        if len(parts) < 2:
            # Blank line or word without a pronunciation
            continue

        word, actual_phonemes = parts
        lexicon[word] = actual_phonemes

    # Predict phonemes
    predicted_phonemes = {}

    start_time = time.perf_counter()

    for word in lexicon:
        phonemes = tagger(word)
        predicted_phonemes[word] = " ".join(phonemes)

    end_time = time.perf_counter()

    # Calculate PER
    # NOTE: distance and length are both measured on the space-joined
    # phoneme strings, i.e. in characters rather than whole phonemes.
    num_errors = 0
    num_missing = 0
    num_phonemes = 0

    for word, actual_phonemes in lexicon.items():
        guessed_phonemes = predicted_phonemes.get(word, "")

        if guessed_phonemes:
            distance = levenshtein(guessed_phonemes, actual_phonemes)
            num_errors += distance
            num_phonemes += len(actual_phonemes)
        else:
            num_missing += 1
            _LOGGER.warning("No pronunciation for %s", word)

    # Explicit check instead of `assert` so it is not stripped under -O
    if num_phonemes <= 0:
        raise AssertionError("No phonemes were read")

    # Calculate results
    per = round(num_errors / num_phonemes, 2)
    elapsed = end_time - start_time
    wps = round(len(predicted_phonemes) / elapsed, 2) if elapsed > 0 else 0.0
    print("PER:", per, "Errors:", num_errors, "words/sec:", wps)

    if num_missing > 0:
        print("Total missing:", num_missing)


# -----------------------------------------------------------------------------


def main():
    """Parse command-line arguments and run the selected sub-command"""
    parser = argparse.ArgumentParser(prog="g2p.py")

    # One sub-parser per sub-command; a sub-command is mandatory
    sub_parsers = parser.add_subparsers()
    sub_parsers.dest = "command"
    sub_parsers.required = True

    # -----
    # Train
    # -----
    train_parser = sub_parsers.add_parser(
        "train", help="Train a new G2P model from a pre-aligned Phonetisaurus corpus"
    )
    train_options = [
        (
            "--corpus",
            {"required": True, "help": "Path to aligned Phonetisaurus g2p corpus"},
        ),
        ("--output", {"required": True, "help": "Path to output tagger model"}),
        ("--c1", {"type": float, "default": 0.0, "help": "L1 penalty"}),
        ("--c2", {"type": float, "default": 1.0, "help": "L2 penalty"}),
        (
            "--max-iterations",
            {
                "type": int,
                "default": 100,
                "help": "Maximum number of training iterations (default: 100)",
            },
        ),
        (
            "--group-separator",
            {"default": "}", "help": "Separator between graphemes and phonemes"},
        ),
        (
            "--item-separator",
            {"default": "|", "help": "Separator between items in a group"},
        ),
        (
            "--remove-phonemes",
            {"nargs": "*", "help": "Remove phonemes from examples"},
        ),
    ]
    for flag, options in train_options:
        train_parser.add_argument(flag, **options)

    train_parser.set_defaults(func=do_train)

    # -------
    # Predict
    # -------
    predict_parser = sub_parsers.add_parser(
        "predict", help="Predict phonemes for word(s)"
    )
    predict_parser.add_argument(
        "--model", required=True, help="Path to G2P tagger model"
    )
    predict_parser.add_argument("texts", nargs="*", help="Words")
    predict_parser.set_defaults(func=do_predict)

    # ----
    # Test
    # ----
    test_parser = sub_parsers.add_parser("test", help="Test G2P model on a lexicon")
    test_parser.add_argument("--model", required=True, help="Path to G2P tagger model")
    test_parser.add_argument(
        "texts", nargs="*", help="Lines with '<word> <phoneme> <phoneme> ...'"
    )
    test_parser.set_defaults(func=do_test)

    # ----------------
    # Shared arguments
    # ----------------
    for command_parser in (train_parser, predict_parser, test_parser):
        command_parser.add_argument(
            "--debug", action="store_true", help="Print DEBUG messages to console"
        )

    args = parser.parse_args()

    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=log_level)

    _LOGGER.debug(args)

    # Dispatch to the handler registered by the chosen sub-command
    args.func(args)


# -----------------------------------------------------------------------------

# Run the command-line interface only when executed as a script
if __name__ == "__main__":
    main()
