import seaborn as sns
import matplotlib.pyplot as plt
sns.set_theme(style="whitegrid")


import metaheuristic_designer as mhd
from metaheuristic_designer.benchmarks import Rastrigin
from metaheuristic_designer.initializers import UniformInitializer
from metaheuristic_designer.strategies import DE
from metaheuristic_designer.algorithms import Algorithm
from metaheuristic_designer.history_tracker import HistoryTracker
from metaheuristic_designer.stopping_condition import StoppingCondition
from metaheuristic_designer.parameter_schedules import ExponentialDecaySchedule

# Problem definition: the Rastrigin benchmark in DIM dimensions, minimized.
# The random state is created from a fixed seed (42) and shared by every
# stochastic component below.
rng = mhd.check_random_state(42)
DIM = 5
objfunc = Rastrigin(DIM, mode="min")

# Initial population of 100 individuals generated within the objective's
# lower/upper bounds (presumably sampled uniformly, per the class name).
initializer = UniformInitializer(
    objfunc.dimension,
    objfunc.lower_bound,
    objfunc.upper_bound,
    population_size=100,
    random_state=rng,
)

# Mutation factor F follows an exponential-decay schedule configured with an
# initial value of 1, a final value of 0.05 and alpha=0.99.
f_schedule = ExponentialDecaySchedule(init_value=1, final_value=0.05, alpha=0.99)

# DE/rand/1 strategy with crossover rate 0.9.
strategy = DE(
    initializer=initializer,
    de_operator_name="DE/rand/1",
    F=f_schedule,
    Cr=0.9,
    name="DE",
    random_state=rng,
)

# Run the strategy against the objective for a fixed budget of 200
# iterations with console reporting silenced. The history tracker is
# configured to record the statistics enabled by the flags below so they
# can be plotted afterwards.
# Use the explicitly imported Algorithm / HistoryTracker names rather than
# going through the `mhd` namespace, so those imports are not dead code.
algo = Algorithm(
    objfunc,
    strategy,
    stop_cond="max_iterations",
    max_iterations=200,
    reporter="silent",
    history_tracker=HistoryTracker(
        track_median=True,
        track_worst=True,
        track_full_objective=True,
        track_diversity=True,
        track_parameters=True,
    ),
)
algo.optimize()

# Per-iteration summary table; the plot below reads its "iteration",
# "best_objective" and "diversity" columns.
df = algo.history_tracker.to_pandas()
# Full objective history (enabled by track_full_objective=True). It is not
# used by the plot below; kept for further analysis.
full_obj_df = algo.history_tracker.to_pandas_full_objective()

# Plot the best objective (left y-axis) and population diversity (right
# y-axis) against the iteration number on a shared x-axis.
fig, ax1 = plt.subplots(figsize=(8, 5))
ax2 = ax1.twinx()  # second y-axis sharing ax1's x-axis

sns.lineplot(data=df, x="iteration", y="best_objective", ax=ax1,
             linewidth=2, color="tab:blue", label="Objective", zorder=100)
sns.lineplot(data=df, x="iteration", y="diversity", ax=ax2,
             linewidth=2, color="tab:red", label="Diversity", zorder=100)

ax1.set_xlabel("Generation")
ax1.set_ylabel("Objective", color="tab:blue")
ax2.set_ylabel("Diversity", color="tab:red")
ax1.set_title("Convergence and Diversity")
# Only keep the grid from the primary axis; the twin axis's ticks would
# produce a second, misaligned grid.
ax2.grid(False)

# NOTE(review): vertical line at iteration 0 — intent is not clear from this
# script; confirm it is wanted.
ax1.axvline(0, color="gray")

# Because `label=` is passed, seaborn's lineplot may attach its own legend to
# each axis. Collect the handles from both axes and merge them into a single
# legend on the top (twin) axis. Then remove ax1's auto-generated legend so
# "Objective" is not shown twice.
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
legend1 = ax1.get_legend()
if legend1 is not None:
    legend1.remove()
ax2.legend(lines1 + lines2, labels1 + labels2, loc="upper right")
plt.tight_layout()
plt.show()