from datetime import datetime
from pathlib import Path

import matplotlib.dates as mdates
import matplotlib.pyplot as plt

import numpy as np

# Import Data
#
# NOTE: this relative path is resolved against the current working directory.
models_path = "../../../demo/linked_polders/"


def _read_timeseries(relative_path):
    """Load a timeseries CSV located below ``models_path`` as a record array.

    ``np.recfromcsv`` is no longer available in the main NumPy namespace as of
    NumPy 2.0, so the equivalent ``np.genfromtxt`` call is spelled out here.
    The keyword arguments reproduce the defaults ``recfromcsv`` applied:
    comma-delimited, column names taken from the header row and lower-cased,
    and per-column dtypes inferred from the data.

    Args:
        relative_path: Path of the CSV file relative to ``models_path``.

    Returns:
        A ``numpy.recarray`` whose columns are accessible as lower-case
        attributes (e.g. ``data.storage_v``).
    """
    data = np.genfromtxt(
        Path(models_path) / relative_path,
        delimiter=",",
        names=True,
        dtype=None,
        encoding=None,
        case_sensitive="lower",
    )
    return data.view(np.recarray)


p1_input = _read_timeseries("polder1/input/timeseries_import.csv")
p1_output = _read_timeseries("polder1/output/timeseries_export.csv")
p5_output = _read_timeseries("polder5/output/timeseries_export.csv")
p10_output = _read_timeseries("polder10/output/timeseries_export.csv")

# Get times as datetime objects
times = [datetime.strptime(stamp, "%Y-%m-%d %H:%M:%S") for stamp in p1_output.time]

# For prettier plotting, make initial flow rate equal to the second. This is
# because the flow at the first timestep is not used or optimized in these models.
# Field access on a record array returns a view, so the assignment below
# modifies the loaded data in place.
for output, flow_columns in (
    (p1_output, ("outflow_q",)),
    (p5_output, ("inflow_q", "outflow_q")),
    (p10_output, ("inflow_q", "outflow_q")),
):
    for column in flow_columns:
        output[column][0] = output[column][1]

# Generate Plot
n_subplots = 3
fig, axarr = plt.subplots(n_subplots, sharex=True, figsize=(8, 3 * n_subplots))
volume_ax, flow_ax, damage_ax = axarr
volume_ax.set_title("Water Volumes and Discharges")

# Upper subplot
storage_max = 30000  # value drawn as the dashed "Storage Max" reference line
volume_ax.set_ylabel("Water Volume [m³]")
volume_ax.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
volume_ax.plot(times, p1_output.storage_v, label="Polder 1", color="b")
volume_ax.plot(times, p5_output.storage_v, label="Polder 5", color="g")
volume_ax.plot(times, p10_output.storage_v, label="Polder 10", color="r")
volume_ax.plot(
    times,
    [storage_max] * len(times),
    label="Storage Max",
    color="tab:gray",
    linestyle="--",
)

# Middle Subplot
# Inflows are drawn solid and outflows dashed, one color per polder. The
# inflow of polder 1 comes from its input file; all other series are outputs.
flow_ax.set_ylabel("Flow Rate [m³/s]")
flow_ax.plot(times, p1_input.inflow_q, label="P-1 Inflow", color="b")
flow_ax.plot(times, p1_output.outflow_q, label="P-1 Outflow", color="b", linestyle="--")
flow_ax.plot(times, p5_output.inflow_q, label="P-5 Inflow", color="g")
flow_ax.plot(times, p5_output.outflow_q, label="P-5 Outflow", color="g", linestyle="--")
flow_ax.plot(times, p10_output.inflow_q, label="P-10 Inflow", color="r")
flow_ax.plot(
    times, p10_output.outflow_q, label="P-10 Outflow", color="r", linestyle="--"
)

# Lower subplot
damage_ax.set_ylabel("Damage Multiplier [-]")
damage_ax.plot(times, p1_output.incurred_damage, label="Polder 1", color="b")
damage_ax.plot(times, p5_output.incurred_damage, label="Polder 5", color="g")
damage_ax.plot(times, p10_output.incurred_damage, label="Polder 10", color="r")

# Format bottom axis label
axarr[-1].xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))

# Shrink margins
fig.tight_layout()

# Shrink each axis and put a legend to the right of the axis
for ax in axarr:
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), frameon=False)

# Acts on the current axes (the last subplot created); the x-axis is shared.
plt.autoscale(enable=True, axis="x", tight=True)

# Output Plot
plt.show()
