import csv
import os

import matplotlib.pyplot as plt
import numpy as np

def readCSV(file_name, delimiter=','):
    """Read a delimited text file into a list of row dictionaries.

    The first line of the file is used as the header; every following line
    becomes a ``dict`` mapping column name to the raw (string) field value.

    Args:
        file_name: Path of the file to read.
        delimiter: Field separator character (default ``','``).

    Returns:
        A list with one ``dict`` per data row, in file order. A file that
        contains only a header (or is empty) yields an empty list.
    """
    # ``with`` guarantees the handle is closed (the previous version leaked
    # it); ``newline=''`` is what the csv module requires so that quoted
    # fields containing line breaks are parsed correctly. The rows are fully
    # materialised by ``list`` before the file is closed.
    with open(file_name, newline='') as handle:
        return list(csv.DictReader(handle, delimiter=delimiter))


def plot(name, subplot, file_name):
    """Draw the two MIP timing curves for one instance onto ``subplot``.

    Reads ``output/<file_name>`` and plots, against the ``budget_percent``
    column, the ``pl_2_time`` ("mip") and ``pl_3_time`` ("mip preprocessed")
    columns divided by 1000, then switches the y axis to a log scale.

    Args:
        name: Title shown above the subplot.
        subplot: Matplotlib ``Axes`` to draw on.
        file_name: CSV file name, relative to the ``output/`` directory.
    """
    subplot.set_title(name)

    records = readCSV('output/{}'.format(file_name))

    # (CSV column, legend label, matplotlib format string) for each curve.
    series = (
        ('pl_2_time', "mip", "v-"),
        ('pl_3_time', "mip preprocessed", "^-"),
    )

    # Parse every curve before drawing anything, so a malformed value fails
    # before any line has been added to the axes.
    # NOTE(review): the /1000 presumably converts ms to s to match the
    # "execution time (s)" axis label -- confirm the units of the CSV.
    curves = [
        (
            label,
            fmt,
            [float(record['budget_percent']) for record in records],
            [float(record[column]) / 1000 for record in records],
        )
        for column, label, fmt in series
    ]

    for label, fmt, xs, ys in curves:
        subplot.plot(xs, ys, fmt, label=label, markersize=3.25)

    subplot.set_yscale('log')


# Script body: draw one execution-time panel per instance in a 2x2 grid and
# save the figure to figures/execution_times.pdf.
fig, axs = plt.subplots(2, 2)
plot("Aude", axs[0, 0], "time_aude.csv")
plot("Montreal", axs[0, 1], "time_quebec_analysis.csv")
plot("Aix", axs[1, 0], "time_biorevaix_analysis.csv")
plot("Marseille", axs[1, 1], "time_marseille_analysis.csv")

# Axis labels only on the outer edges of the grid: y labels on the left
# column, x labels on the bottom row.
axs[0, 0].set(xlabel=None, ylabel='execution time (s)')
axs[0, 1].set(xlabel=None, ylabel=None)
axs[1, 0].set(xlabel='budget percent', ylabel='execution time (s)')
axs[1, 1].set(xlabel='budget percent', ylabel=None)

# Leave room below the bottom row for the legend.
fig.subplots_adjust(bottom=0.2, top=0.91, wspace=0.2, hspace=0.35)
# One legend for the whole figure: attached to the bottom-left panel and
# anchored below it and shifted right, roughly between the two bottom panels.
axs[1, 0].legend(loc='upper center', bbox_to_anchor=(1.05, -0.35), ncol=3)

fig.set_size_inches(8, 5)
# NOTE(review): this runs after the titles, axis labels and legend already
# exist, so it most likely does not resize them (only artists created later,
# e.g. lazily generated tick labels, may pick it up). Confirm the intended
# font size before moving it above plt.subplots().
plt.rcParams.update({'font.size': 18})

# savefig does not create missing directories; make sure the target exists.
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/execution_times.pdf", dpi=500)
plt.show()