diff --git a/basics.py b/basics.py index cc06999..d94e6e4 100644 --- a/basics.py +++ b/basics.py @@ -2,13 +2,15 @@ Plot total cases of countries over time on log scale """ import matplotlib.pyplot as pp +import numpy as np def plot(data, countries): figsize = (10,5) - tcp, tc = pp.subplots(figsize=figsize) - tdp, td = pp.subplots(figsize=figsize) - ncp, nc = pp.subplots(figsize=figsize) - ndp, nd = pp.subplots(figsize=figsize) + tcp, tc = pp.subplots(figsize=figsize) # total cases + tdp, td = pp.subplots(figsize=figsize) # total deaths + tip, ti = pp.subplots(figsize=figsize) # total (currently) infected + ncp, nc = pp.subplots(figsize=figsize) # new cases + ndp, nd = pp.subplots(figsize=figsize) # new deaths for loc in data: if loc not in countries: continue @@ -26,12 +28,20 @@ def plot(data, countries): # new deaths nd.plot(time, new_deaths, label=f"{loc}", marker=".") - for ax, fig, name in [(tc, tcp, "total_cases"), (td, tdp, "total_deaths"), (nc, ncp, "new_cases"), (nd, ndp, "new_deaths")]: + # currently infected + delay = 21 + current_infected = np.array(total_cases[delay:]) - np.array(total_deaths[:-delay]) - np.array(total_cases[:-delay]) + + ti.plot(time[:-delay], current_infected, label=f"{loc}", marker=".") + + for ax, fig, name in [(tc, tcp, "total_cases"), (td, tdp, "total_deaths"), (nc, ncp, "new_cases"), (nd, ndp, "new_deaths"), (ti, tip, "current_infected")]: ax.set_yscale("log") ax.set_ylabel(name) for tick in ax.get_xticklabels(): tick.set_rotation(45) ax.legend(frameon=False) + ax.grid(True) fig.tight_layout() + fig.savefig(name+".png")