#! /usr/bin/env python3
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import h5py

import os

from argparse import ArgumentParser


def gridlines(obj, x, y):
    """Overlay the lines of a structured (curvilinear) mesh on ``obj``.

    ``obj`` is anything exposing a matplotlib-style ``plot(xs, ys, **kwargs)``
    method, e.g. an ``Axes`` or the ``pyplot`` module itself.

    Interior lines are drawn thin and translucent. The four boundary lines
    (first/last row, first/last column) are drawn slightly thicker and opaque.

    Args:
        obj: Drawing target with a ``plot`` method.
        x: 2-D array of node x-coordinates.
        y: 2-D array of node y-coordinates, same shape as ``x``.
    """
    grey = "#7f7f7f"
    faint = {"color": grey, "linewidth": 0.1, "alpha": 0.3}

    # Interior lines of constant first index.
    for row in range(1, x.shape[0] - 1):
        obj.plot(x[row, :], y[row, :], **faint)
    # Interior lines of constant second index.
    for col in range(1, x.shape[1] - 1):
        obj.plot(x[:, col], y[:, col], **faint)

    # Boundary lines, in the order: first row, last row, first column, last column.
    edges = (
        (x[0, :], y[0, :]),
        (x[-1, :], y[-1, :]),
        (x[:, 0], y[:, 0]),
        (x[:, -1], y[:, -1]),
    )
    for edge_x, edge_y in edges:
        obj.plot(edge_x, edge_y, color=grey, linewidth=0.2)


def _centered_levels(center, lo, hi, count, margin=1.2):
    """Return ``count`` evenly spaced contour levels symmetric about ``center``.

    The half-width is ``margin`` times the larger distance from ``center`` to
    ``lo`` or ``hi``. A zero half-width (field identically equal to
    ``center``) is widened to 1 so the levels stay strictly increasing, which
    ``contourf`` requires.
    """
    half = margin * max(abs(lo - center), abs(hi - center))
    if half == 0:
        half = 1.0
    return np.linspace(center - half, center + half, count)


def _add_colorbar(ax, cmap, levels):
    """Attach a colorbar to ``ax`` spanning ``levels[0]`` .. ``levels[-1]`` in ``cmap``."""
    norm = mpl.colors.Normalize(vmin=levels[0], vmax=levels[-1])
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    plt.colorbar(sm, ax=ax)


def plot_all(grids, error: bool, save: bool, filename="figure.png"):
    """Plot rho, rho*u, rho*v and e for all grids in a 2x2 figure.

    Each panel uses one set of contour levels shared by every grid, so the
    colours are comparable across grids.

    Args:
        grids: Sequence of dicts with 2-D arrays under the keys "x", "y",
            "rho", "rhou", "rhov" and "e".
        error: If True, every scale is centred on zero and e uses the
            symmetric colormap. Otherwise rho and rho*u are centred on 1
            (presumably the normalised reference value; confirm), and e spans
            its data range in greys. rho*v is always centred on zero.
        save: If True, the figure is also written to ``filename`` (600 dpi)
            before being shown.
        filename: Output path used when ``save`` is True.

    Raises:
        ValueError: If ``grids`` is empty.
    """
    if not grids:
        raise ValueError("no grids to plot")

    sym_cmap = plt.get_cmap("PiYG")  # Symmetric around zero
    e_cmap = sym_cmap if error else plt.get_cmap("Greys")

    def field_range(key):
        # Global min/max of one variable across every grid.
        return (
            min(np.min(g[key]) for g in grids),
            max(np.max(g[key]) for g in grids),
        )

    center = 0.0 if error else 1.0
    rho_levels = _centered_levels(center, *field_range("rho"), 34)
    rhou_levels = _centered_levels(center, *field_range("rhou"), 20)
    rhov_levels = _centered_levels(0.0, *field_range("rhov"), 20)

    min_e, max_e = field_range("e")
    if error:
        e_levels = _centered_levels(0.0, min_e, max_e, 20, margin=1.0)
    elif min_e == max_e:
        # Constant energy: widen so the levels are strictly increasing.
        e_levels = np.linspace(min_e - 1.0, max_e + 1.0)
    else:
        e_levels = np.linspace(min_e, max_e)

    _, axarr = plt.subplots(2, 2)

    panels = [
        (axarr[0, 0], "rho", sym_cmap, rho_levels, r"$\rho$"),
        (axarr[0, 1], "rhou", sym_cmap, rhou_levels, r"$\rho u$"),
        (axarr[1, 0], "rhov", sym_cmap, rhov_levels, r"$\rho v$"),
        (axarr[1, 1], "e", e_cmap, e_levels, r"$e$"),
    ]

    for g in grids:
        x = g["x"]
        y = g["y"]
        for ax, key, cmap, levels, _ in panels:
            ax.contourf(x, y, g[key], cmap=cmap, levels=levels)
            gridlines(ax, x, y)

    for ax, _, cmap, levels, title in panels:
        ax.set_title(title)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        _add_colorbar(ax, cmap, levels)

    if save:
        plt.savefig(filename, bbox_inches="tight", dpi=600)

    plt.show()


def plot_total_error(grids, save: bool, filename="figure.png"):
    """Plot |rho| + |rho*u| + |rho*v| + |e| pointwise over all grids.

    This is meant for solutions whose fields hold errors. The grey scale runs
    from 0 to the largest value over every grid, so all grids share one
    colour mapping.

    Args:
        grids: Sequence of dicts with 2-D arrays under the keys "x", "y",
            "rho", "rhou", "rhov" and "e".
        save: If True, the figure is also written to ``filename`` (600 dpi)
            before being shown.
        filename: Output path used when ``save`` is True.
    """
    cmap = plt.get_cmap("Greys")

    total_err = [
        np.abs(g["rho"]) + np.abs(g["rhou"]) + np.abs(g["rhov"]) + np.abs(g["e"])
        for g in grids
    ]

    r = max(np.max(err) for err in total_err)
    if r == 0:
        # Identically zero error: keep the levels strictly increasing for contourf.
        r = 1.0

    levels = np.linspace(0, r, 30)

    for g, err in zip(grids, total_err):
        x = g["x"]
        y = g["y"]

        plt.contourf(x, y, err, cmap=cmap, levels=levels)
        gridlines(plt, x, y)

    plt.title("Total error")
    norm = mpl.colors.Normalize(vmin=levels[0], vmax=levels[-1])
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    plt.colorbar(sm)

    plt.xlabel("x")
    plt.ylabel("y")

    if save:
        # Use this function's own argument, not script-level state.
        plt.savefig(filename, bbox_inches="tight", dpi=600)

    plt.show()


def plot_pressure(grids, save: bool, filename="figure.png", gamma=1.4, mach=0.5):
    """Plot pressure derived from the conserved variables on all grids.

    Pressure is computed as
    ``p = (gamma - 1) * (e - (rhou**2 + rhov**2) / (2 * rho))``.
    The colour scale is symmetric about ``p_inf = 1 / (gamma * mach**2)`` and
    is shared by every grid.

    Args:
        grids: Sequence of dicts with 2-D arrays under the keys "x", "y",
            "rho", "rhou", "rhov" and "e".
        save: If True, the figure is also written to ``filename`` (600 dpi)
            before being shown.
        filename: Output path used when ``save`` is True.
        gamma: Ratio of specific heats. The default 1.4 is an assumption that
            must match the solver's setting.
        mach: Reference Mach number used for ``p_inf``. The default 0.5 is an
            assumption that must match the solver's setting.
    """
    cmap = plt.get_cmap("RdGy")

    p = [
        (gamma - 1) * (g["e"] - (g["rhou"] ** 2 + g["rhov"] ** 2) / (2 * g["rho"]))
        for g in grids
    ]

    # Per-grid extrema reduced with numpy. This avoids repeatedly copying
    # into a growing flat array and still propagates NaN.
    max_p = np.max([np.max(p_) for p_ in p])
    min_p = np.min([np.min(p_) for p_ in p])

    p_inf = 1 / (gamma * mach ** 2)

    r = max(max_p - p_inf, p_inf - min_p)
    if r == 0:
        # Uniform pressure equal to p_inf: keep the levels strictly increasing.
        r = 1.0

    levels = np.linspace(p_inf - r, p_inf + r, 30)

    for g, p_ in zip(grids, p):
        x = g["x"]
        y = g["y"]

        plt.contourf(x, y, p_, cmap=cmap, levels=levels)
        gridlines(plt, x, y)

    plt.title("Pressure")
    norm = mpl.colors.Normalize(vmin=levels[0], vmax=levels[-1])
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    plt.colorbar(sm)

    plt.xlabel("x")
    plt.ylabel("y")

    if save:
        plt.savefig(filename, bbox_inches="tight", dpi=600)

    plt.show()


def read_from_file(filename):
    """Load every grid stored as a top-level group in an HDF5 solution file.

    Each top-level HDF5 group must contain datasets "x", "y", "rho", "rhou",
    "rhov" and "e". The coordinates are read in full. For each variable only
    the last entry along the first axis is kept (presumably the final time
    step; confirm against the writer). Non-group top-level objects are
    skipped.

    The file is closed before returning. All data is copied into numpy arrays
    first, so the returned dicts do not depend on the open file.

    Args:
        filename: Path to the HDF5 file.

    Returns:
        List of dicts mapping "x", "y", "rho", "rhou", "rhov", "e" to arrays.
    """
    grids = []

    with h5py.File(filename, "r") as file:
        for groupname in file:
            group = file[groupname]
            if not isinstance(group, h5py.Group):
                continue
            grids.append(
                {
                    "x": group["x"][:],
                    "y": group["y"][:],
                    "rho": group["rho"][-1, :, :],
                    "rhou": group["rhou"][-1, :, :],
                    "rhov": group["rhov"][-1, :, :],
                    "e": group["e"][-1, :, :],
                }
            )

    return grids


def read_from_legacy_file(filename):
    """Load grids from the legacy raw-binary solution format.

    The layout read here uses native byte order:

    * ``uint32`` number of grids.
    * For each grid:

      * ``uint32 neta`` and ``uint32 nxi``.
      * Six ``float64`` arrays of ``neta * nxi`` values, in the order
        x, y, rho, rhou, rhov, e. Each array is reshaped to ``(neta, nxi)``.

    Args:
        filename: Path to the binary file.

    Returns:
        List of dicts mapping "x", "y", "rho", "rhou", "rhov", "e" to arrays.

    Raises:
        ValueError: If the file ends before all declared data has been read.
    """
    fields = ("x", "y", "rho", "rhou", "rhov", "e")

    def read_exact(f, dtype, count):
        # np.fromfile silently returns fewer items at EOF; make that explicit.
        data = np.fromfile(f, dtype=dtype, count=count)
        if data.size != count:
            raise ValueError(
                "{}: expected {} values, got {} (truncated file?)".format(
                    filename, count, data.size
                )
            )
        return data

    grids = []

    with open(filename, "rb") as f:
        # Index the element: int() on a 1-element array is deprecated in NumPy.
        ngrids = int(read_exact(f, np.uint32, 1)[0])
        for _ in range(ngrids):
            # Python ints so neta * nxi cannot wrap around in uint32 arithmetic.
            neta, nxi = (int(n) for n in read_exact(f, np.uint32, 2))
            size = neta * nxi
            grids.append(
                {
                    name: read_exact(f, np.double, size).reshape((neta, nxi))
                    for name in fields
                }
            )

    return grids


if __name__ == "__main__":
    # Command-line entry point: parse options, locate the solution file, and
    # dispatch to one of the plotting functions.
    parser = ArgumentParser(description="Plot a solution from the eulersolver")
    parser.add_argument("filename", metavar="filename", type=str)
    parser.add_argument(
        "-e",
        help="Scale is centered around zero (implies -a)",
        action="store_true",
        dest="error",
    )
    parser.add_argument(
        "-te", help="Plots total error", action="store_true", dest="total_error"
    )
    parser.add_argument("-s", help="Save figure", action="store_true", dest="save")
    parser.add_argument(
        "-o",
        help="Output of saved figure",
        type=str,
        default="figure.png",
        dest="output",
    )
    parser.add_argument(
        "-a", help="Show all four variables", action="store_true", dest="all"
    )

    args = parser.parse_args()
    filename = args.filename
    if not os.path.isfile(filename):
        # A bare number N selects the file "solutionNNN.bin".
        try:
            number = int(filename)
        except ValueError:
            parser.error(
                "{!r} is neither an existing file nor a solution number".format(
                    filename
                )
            )
        filename = "solution{:03}.bin".format(number)
        if not os.path.isfile(filename):
            parser.error("solution file {!r} does not exist".format(filename))

    grids = read_from_file(filename)

    if args.all or args.error:
        plot_all(grids, args.error, args.save, args.output)
    elif args.total_error:
        plot_total_error(grids, args.save, args.output)
    else:
        plot_pressure(grids, args.save, args.output)