diff --git a/basics.py b/basics.py index f3b31ad..765cc60 100644 --- a/basics.py +++ b/basics.py @@ -5,34 +5,32 @@ import matplotlib.pyplot as pp def plot(data, countries): + tcp, tc = pp.subplots() + tdp, td = pp.subplots() + ncp, nc = pp.subplots() + ndp, nd = pp.subplots() for loc in data: if loc not in countries: continue time, new_cases, new_deaths, total_cases, total_deaths = data[loc] # total cases - pp.figure("total_cases") - pp.plot(time, total_cases, label=f"{loc}") + tc.plot(time, total_cases, label=f"{loc}") # total deaths - pp.figure("total_deaths") - pp.plot(time, total_deaths, label=f"{loc}") + td.plot(time, total_deaths, label=f"{loc}") # new cases - pp.figure("new_cases") - pp.plot(time, new_cases, label=f"{loc}") + nc.plot(time, new_cases, label=f"{loc}") # new deaths - pp.figure("new_deaths") - pp.plot(time, new_deaths, label=f"{loc}") + nd.plot(time, new_deaths, label=f"{loc}") - for name in ["total_cases", "total_deaths", "new_cases", "new_deaths"]: - postprocess(name) + for ax, fig, name in [(tc, tcp, "total_cases"), (td, tdp, "total_deaths"), (nc, ncp, "new_cases"), (nd, ndp, "new_deaths")]: + ax.set_yscale("log") + for tick in ax.get_xticklabels(): + tick.set_rotation(45) + ax.legend(frameon=False) + fig.tight_layout() -def postprocess(name): - pp.yscale("log") - pp.xticks(rotation=90) - pp.legend(frameon=False) - pp.tight_layout() - - pp.savefig(f"{name}.png") + fig.savefig(name+".png")