# Copyright ExplosionAI GmbH, released under BSD.
import numpy
import numpy.random
from .py import gemm, einsum
from timeit import default_timer as timer

# Seed NumPy's global random state so the data drawn via numpy.random.uniform
# in create_data() is the same on every run.
numpy.random.seed(0)


def create_data(nO, nI, batch_size):
    """Build random float32 input and weight matrices for the benchmarks.

    Samples are drawn with ``numpy.random.uniform(-1.0, 1.0, ...)`` from the
    global NumPy random state; the inputs are drawn before the weights.

    nO (int): Number of rows of the weight matrix.
    nI (int): Number of columns of both matrices.
    batch_size (int): Number of rows of the input matrix.
    RETURNS (tuple): ``(X, W)`` where ``X`` has shape ``(batch_size, nI)`` and
        ``W`` has shape ``(nO, nI)``, both with dtype ``"f"``.
    """

    def random_float32(shape):
        # Add the freshly drawn samples into a zeroed float32 buffer in
        # place, so the result keeps the buffer's dtype.
        buffer = numpy.zeros(shape, dtype="f")
        numpy.add(buffer, numpy.random.uniform(-1.0, 1.0, buffer.shape), out=buffer)
        return buffer

    inputs = random_float32((batch_size, nI))
    weights = random_float32((nO, nI))
    return inputs, weights


def get_numpy_blas():
    """Return the name of the BLAS library NumPy was built against.

    Older NumPy builds expose ``numpy.__config__.blas_opt_info`` with a
    ``"libraries"`` list; that attribute is missing from newer builds, which
    instead report build dependencies through
    ``numpy.show_config(mode="dicts")``. Both sources are tried in that order.

    RETURNS (str): The first library listed by the legacy config, else the
        BLAS name from the build-dependency report, else ``"unknown"`` when
        neither source yields a name.
    """
    # Legacy path: attribute may be absent, empty, or lack "libraries".
    legacy_info = getattr(numpy.__config__, "blas_opt_info", None)
    if legacy_info:
        libraries = legacy_info.get("libraries")
        if libraries:
            return libraries[0]
    # Newer path: show_config() only accepts ``mode`` on newer releases, so
    # a TypeError here means we are on an old build without BLAS info.
    try:
        config = numpy.show_config(mode="dicts")
    except (TypeError, AttributeError):
        return "unknown"
    if not isinstance(config, dict):
        return "unknown"
    dependencies = config.get("Build Dependencies") or {}
    blas = dependencies.get("blas") or {}
    name = blas.get("name")
    return str(name) if name else "unknown"


def numpy_gemm(X, W, n=1000):
    """Benchmark loop: compute ``numpy.dot(X, W)`` ``n`` times.

    Each product is written into one preallocated float32 buffer of shape
    ``(X.shape[0], W.shape[0])``; the buffer's sum is accumulated and the
    buffer is zeroed before the next iteration. Because ``W`` is not
    transposed, the product only fits the buffer when ``W`` is square.

    X (numpy.ndarray): Left-hand matrix.
    W (numpy.ndarray): Right-hand 2-D matrix.
    n (int): Number of repetitions.
    RETURNS (None): The accumulated sum is printed as ``Total: <value>``.
    """
    n_out, _n_in = W.shape
    n_rows = X.shape[0]
    checksum = 0.0
    result = numpy.zeros((n_rows, n_out), dtype="f")
    for _ in range(n):
        numpy.dot(X, W, out=result)
        checksum += result.sum()
        result.fill(0)
    print("Total:", checksum)


def blis_gemm(X, W, n=1000):
    """Benchmark loop: call the imported ``gemm(X, W, out=...)`` ``n`` times.

    Mirrors ``numpy_gemm``: one float32 buffer of shape
    ``(X.shape[0], W.shape[0])`` is reused as the ``out`` argument, its sum
    is accumulated after every call, and it is zeroed before the next one.

    X (numpy.ndarray): First matrix passed to ``gemm``.
    W (numpy.ndarray): Second (2-D) matrix passed to ``gemm``.
    n (int): Number of repetitions.
    RETURNS (None): The accumulated sum is printed as ``Total: <value>``.
    """
    n_out, _n_in = W.shape
    n_rows = X.shape[0]
    checksum = 0.0
    result = numpy.zeros((n_rows, n_out), dtype="f")
    for _ in range(n):
        gemm(X, W, out=result)
        checksum += result.sum()
        result.fill(0.0)
    print("Total:", checksum)


def numpy_einsum(X, W, n=1000):
    """Benchmark loop: evaluate ``numpy.einsum("ab,cb->ca", X, W)`` ``n`` times.

    Results go into one preallocated float32 buffer of shape
    ``(W.shape[0], X.shape[0])``; the buffer's sum is accumulated and the
    buffer is zeroed before the next iteration.

    X (numpy.ndarray): Operand indexed ``ab``.
    W (numpy.ndarray): 2-D operand indexed ``cb``.
    n (int): Number of repetitions.
    RETURNS (None): The accumulated sum is printed as ``Total: <value>``.
    """
    n_out, _n_in = W.shape
    n_rows = X.shape[0]
    checksum = 0.0
    result = numpy.zeros((n_out, n_rows), dtype="f")
    for _ in range(n):
        numpy.einsum("ab,cb->ca", X, W, out=result)
        checksum += result.sum()
        result.fill(0.0)
    print("Total:", checksum)


def blis_einsum(X, W, n=1000):
    """Benchmark loop: call the imported ``einsum("ab,cb->ca", ...)`` ``n`` times.

    Mirrors ``numpy_einsum``: one float32 buffer of shape
    ``(W.shape[0], X.shape[0])`` is reused as the ``out`` argument, its sum
    is accumulated after every call, and it is zeroed before the next one.

    X (numpy.ndarray): Operand indexed ``ab``.
    W (numpy.ndarray): 2-D operand indexed ``cb``.
    n (int): Number of repetitions.
    RETURNS (None): The accumulated sum is printed as ``Total: <value>``.
    """
    n_out, _n_in = W.shape
    n_rows = X.shape[0]
    checksum = 0.0
    result = numpy.zeros((n_out, n_rows), dtype="f")
    for _ in range(n):
        einsum("ab,cb->ca", X, W, out=result)
        checksum += result.sum()
        result.fill(0.0)
    print("Total:", checksum)


def _timed_run(label, benchmark, *args, **kwargs):
    """Print ``label``, run ``benchmark(*args, **kwargs)``, print elapsed time.

    RETURNS (float): Wall-clock seconds spent inside ``benchmark``.
    """
    print(label)
    start = timer()
    benchmark(*args, **kwargs)
    elapsed = timer() - start
    print("%.2f seconds" % elapsed)
    return elapsed


def main(nI=128 * 3, nO=128 * 3, batch_size=2000, n_iter=1000):
    """Run the Blis and NumPy gemm/einsum benchmarks and print their timings.

    nI (int): Number of columns of the input and weight matrices.
    nO (int): Number of rows of the weight matrix. ``numpy_gemm`` multiplies
        by the untransposed weights, so it only works when ``nO == nI``.
    batch_size (int): Number of rows of the input matrix.
    n_iter (int): Number of repetitions for each of the four benchmarks.
    """
    print(
        "Setting up data for gemm. {n_iter} iters,  "
        "nO={nO} nI={nI} batch_size={batch_size}".format(
            n_iter=n_iter, nO=nO, nI=nI, batch_size=batch_size
        )
    )
    numpy_blas = get_numpy_blas()
    # create_data's signature is (nO, nI, batch_size); pass them in that order.
    X1, W1 = create_data(nO, nI, batch_size)
    # Separate copies for the Blis runs so both libraries see equal data.
    X2 = X1.copy()
    W2 = W1.copy()
    _timed_run("Blis gemm...", blis_gemm, X2, W2, n=n_iter)
    _timed_run("Numpy (%s) gemm..." % numpy_blas, numpy_gemm, X1, W1, n=n_iter)
    _timed_run("Blis einsum ab,cb->ca", blis_einsum, X2, W2, n=n_iter)
    _timed_run(
        "Numpy (%s) einsum ab,cb->ca" % numpy_blas, numpy_einsum, X2, W2, n=n_iter
    )


# Run the benchmark only when this file is executed as the main module.
# A bare ``if __name__:`` is always true, which would also run it on import.
if __name__ == "__main__":
    main()
