import seaborn as sns
import matplotlib.pyplot as plt
# Apply seaborn's "whitegrid" theme globally, before any figure is created.
sns.set_theme(style="whitegrid")


import metaheuristic_designer as mhd
from metaheuristic_designer.benchmarks import Rastrigin
from metaheuristic_designer.initializers import UniformInitializer
from metaheuristic_designer.strategies import DE
from metaheuristic_designer.algorithms import Algorithm
from metaheuristic_designer.history_tracker import HistoryTracker
from metaheuristic_designer.stopping_condition import StoppingCondition
from metaheuristic_designer.parameter_schedules import ExponentialDecaySchedule

# One random state, seeded with 42, handed to both the initializer and the
# strategy below.
rng = mhd.check_random_state(42)

# Problem definition: the Rastrigin benchmark built with DIM and mode="min".
DIM = 5
objfunc = Rastrigin(DIM, mode="min")

# Uniform initializer configured from the objective's dimension and bounds,
# with a population of 100.
population_init = UniformInitializer(
    objfunc.dimension,
    objfunc.lower_bound,
    objfunc.upper_bound,
    population_size=100,
    random_state=rng,
)

# F is not a constant: it is driven by an exponential-decay schedule
# (init_value=1, final_value=0.05, alpha=0.99). Cr stays fixed at 0.9.
mutation_schedule = ExponentialDecaySchedule(
    init_value=1,
    final_value=0.05,
    alpha=0.99,
)

strategy = DE(
    initializer=population_init,
    de_operator_name="DE/rand/1",
    F=mutation_schedule,
    Cr=0.9,
    name="DE",
    random_state=rng,
)

# History tracker with every tracking option used in this example enabled.
tracker = mhd.HistoryTracker(
    track_median=True,
    track_worst=True,
    track_full_objective=True,
    track_diversity=True,
    track_parameters=True,
)

# Run the DE strategy on the objective for 200 iterations with the "silent"
# reporter, recording history through the tracker above.
algo = mhd.Algorithm(
    objfunc,
    strategy,
    stop_cond="max_iterations",
    max_iterations=200,
    reporter="silent",
    history_tracker=tracker,
)
algo.optimize()

# Summary table exported by the history tracker; the plots below read its
# "iteration", "best_objective", "diversity" and parameter columns.
df = algo.history_tracker.to_pandas()

# Full objective table, reshaped from wide to long so that each row holds one
# (iteration, individual, objective) triple.
full_obj_df = algo.history_tracker.to_pandas_full_objective()
long_df = full_obj_df.melt(
    id_vars="iteration",
    var_name="individual",
    value_name="objective",
)

# Columns of `df` drawn in the "Scheduled Parameters" panel.
param_cols = [
    "DE/rand/1.F",
    "DE/rand/1.Cr",
]

# Single 14x10 figure holding the four diagnostic panels (2x2 grid).
fig = plt.figure(figsize=(14, 10))

# Panel 1 (top-left): best objective value against iteration.
ax1 = fig.add_subplot(221)
sns.lineplot(x="iteration", y="best_objective", data=df, ax=ax1)
ax1.set(ylabel="Objective", title="Best Objective")
# Gray reference lines at x = 0 and y = 0.
ax1.axvline(0, color="gray")
ax1.axhline(0, color="gray")

# Panel 2 (top-right): distribution of objective values across the population,
# one box per sampled iteration (only iterations divisible by 10 are kept so
# the boxes stay readable).
ax2 = fig.add_subplot(2, 2, 2)
plot_data = long_df[long_df["iteration"] % 10 == 0]
sns.boxplot(data=plot_data, x="iteration", y="objective", ax=ax2, width=0.6)
ax2.set_ylabel("Objective")
ax2.set_title("Fitness Distribution")
# Enable the grid explicitly. A bare `grid()` call *toggles* visibility, which
# under the "whitegrid" theme switches the already-visible horizontal grid
# lines off instead of turning the grid on.
ax2.grid(True)
# Gray reference line at y = 0.
ax2.axhline(0, color="gray")

# Panel 3 (bottom-left): best objective (left y-axis, blue) and population
# diversity (right y-axis, red) against iteration on a shared x-axis.
ax3 = fig.add_subplot(2, 2, 3)
ax3_twin = ax3.twinx()
# `legend=False` stops seaborn from drawing a separate legend on each axis;
# the `label` is still attached to the line, so both entries can be merged
# into one legend below. Without it, a stray "Objective" legend is left on
# `ax3` next to the merged one.
sns.lineplot(
    data=df, x="iteration", y="diversity", ax=ax3_twin,
    color="tab:red", label="Diversity", legend=False,
)
sns.lineplot(
    data=df, x="iteration", y="best_objective", ax=ax3,
    color="tab:blue", label="Objective", legend=False,
)
ax3.set_title("Diversity vs Objective")
ax3.set_xlabel("Generation")
ax3.set_ylabel("Objective", color="tab:blue")
ax3_twin.set_ylabel("Diversity", color="tab:red")
# Keep only the primary axis grid so the two grids do not overlap.
ax3_twin.grid(False)
ax3.axvline(0, color="gray")

# Merge the handles of both axes into a single legend, drawn on the twin axis.
lines1, labels1 = ax3.get_legend_handles_labels()
lines2, labels2 = ax3_twin.get_legend_handles_labels()
ax3_twin.legend(lines1 + lines2, labels1 + labels2, loc="upper right")

# Panel 4 (bottom-right): tracked strategy parameters against iteration.
# Only columns actually present in `df` are drawn, so a parameter that was
# not recorded is skipped instead of making the whole script fail.
ax4 = fig.add_subplot(2, 2, 4)
present_param_cols = [col for col in param_cols if col in df.columns]
for col in present_param_cols:
    sns.lineplot(data=df, x="iteration", y=col, ax=ax4, label=col)
ax4.set_ylabel("Parameter values")
ax4.set_title("Scheduled Parameters")
# Skip the legend when nothing was plotted; matplotlib would otherwise warn
# about a legend without labelled artists.
if present_param_cols:
    ax4.legend()
# Gray reference lines at x = 0 and y = 0.
ax4.axvline(0, color="gray")
ax4.axhline(0, color="gray")

# Adjust panel spacing on the figure (it is the current figure at this point),
# then display it.
fig.tight_layout()
plt.show()