from datetime import datetime
from pathlib import Path

import matplotlib.dates as mdates
import matplotlib.pyplot as plt

import numpy as np

# Import Data
# Root of the demo models. The path is relative to the current working
# directory, so run this script from its own folder.
models_path = Path("../../../demo/polders_boezem")


def _read_timeseries_csv(path):
    """Read a comma-separated timeseries file into a record array.

    Drop-in replacement for ``np.recfromcsv`` (deprecated in NumPy 1.25 and
    removed in NumPy 2.0) with the same defaults: the header row supplies
    the field names (lower-cased), column dtypes are inferred per column,
    and the result is viewed as ``np.recarray`` so columns are reachable
    as attributes (e.g. ``data.outflow_q``).
    """
    data = np.genfromtxt(
        path,
        delimiter=",",
        names=True,
        dtype=None,
        case_sensitive="lower",
        encoding=None,
    )
    return data.view(np.recarray)


p1_input = _read_timeseries_csv(
    models_path / "linked_polders" / "polder1" / "input" / "timeseries_import.csv"
)
# Polders are numbered 1..10; each has its own exported timeseries.
polder_outputs = {
    i: _read_timeseries_csv(
        models_path
        / "linked_polders"
        / f"polder{i}"
        / "output"
        / "timeseries_export.csv"
    )
    for i in range(1, 11)
}
boezem_output = _read_timeseries_csv(
    models_path / "boezem" / "output" / "timeseries_export.csv"
)

# Convert the boezem export's "time" column (text timestamps) into datetime
# objects so matplotlib can place them on a date axis.
times = [
    datetime.strptime(stamp, "%Y-%m-%d %H:%M:%S") for stamp in boezem_output.time
]

# Cosmetic fix for plotting only: the flow at the first timestep is neither
# used nor optimized by these models, so overwrite it with the second
# timestep's value to avoid a meaningless jump at the start of each line.
for polder in polder_outputs.values():
    polder.outflow_q[0] = polder.outflow_q[1]
# Record-array attribute access returns views, so these writes modify
# boezem_output in place.
for series in (boezem_output.polder_inflow_q, boezem_output.outflow_q):
    series[0] = series[1]

# Generate Plot
n_subplots = 2
fig, axarr = plt.subplots(n_subplots, sharex=True, figsize=(8, 3 * n_subplots))
axarr[0].set_title("Water Volumes and Discharges")

# Upper subplot: boezem storage volume with horizontal reference lines.
# NOTE(review): presumably these mirror the volume bounds of the boezem
# model; keep them in sync with that model's configuration.
boezem_v_max = 600
boezem_v_min = 400
axarr[0].set_ylabel("Water Volume [m³]")
axarr[0].plot(times, boezem_output.channel_v, label="Boezem", color="r")
axarr[0].plot(times, [boezem_v_max] * len(times), label="Boezem Max", color="tab:gray")
axarr[0].plot(times, [boezem_v_min] * len(times), label="Boezem Min", color="tab:brown")

# Lower subplot: flow entering polder 1, passing from each polder to the
# next, and finally leaving the last polder into (and out of) the boezem.
# Polders are keyed 1..n in polder_outputs; derive n instead of hard-coding.
n_polders = len(polder_outputs)
axarr[1].set_ylabel("Flow Rate [m³/s]")
axarr[1].plot(times, p1_input.inflow_q, label="Polder 1 Inflow")
for i in range(1, n_polders):
    axarr[1].plot(times, polder_outputs[i].outflow_q, label=f"{i} -> {i+1}")
axarr[1].plot(
    times,
    boezem_output.polder_inflow_q,
    label=f"P-{n_polders} Outflow\n(to Boezem)",
    color="r",
    linestyle="--",
)
axarr[1].plot(times, boezem_output.outflow_q, label="Boezem Outflow", color="g")

# Format bottom axis tick labels as hours:minutes (x-axis is shared).
axarr[-1].xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))

# Shrink margins
fig.tight_layout()

# Shrink each axis horizontally to make room for a legend to its right.
for ax in axarr:
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.78, box.height])
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), frameon=False)

# Fit the (shared) x-axis tightly to the data. Target the axes explicitly
# rather than relying on pyplot's implicit "current axes".
axarr[-1].autoscale(enable=True, axis="x", tight=True)

# Output Plot
plt.show()
