Skip to content

Commit 47effec

Browse files
committed
MAINT: Update ma tutorial plt patterns.
1 parent 132fcb7 commit 47effec

File tree

1 file changed

+26
-18
lines changed

1 file changed

+26
-18
lines changed

content/tutorial-ma.md

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,11 @@ First of all, we can plot the whole set of data we have and see what it looks li
131131
import matplotlib.pyplot as plt
132132
133133
selected_dates = [0, 3, 11, 13]
134-
plt.plot(dates, nbcases.T, "--")
135-
plt.xticks(selected_dates, dates[selected_dates])
136-
plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
134+
135+
fig, ax = plt.subplots()
136+
ax.plot(dates, nbcases.T, "--")
137+
ax.set_xticks(selected_dates, dates[selected_dates])
138+
ax.set_title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
137139
```
138140

139141
The graph has a strange shape from January 24th to February 1st. It would be interesting to know where this data comes from. If we look at the `locations` array we extracted from the `.csv` file, we can see that we have two columns, where the first would contain regions and the second would contain the name of the country. However, only the first few rows contain data for the the first column (province names in China). Following that, we only have country names. So it would make sense to group all the data from China into a single row. For this, we'll select from the `nbcases` array only the rows for which the second entry of the `locations` array corresponds to China. Next, we'll use the [numpy.sum](https://numpy.org/devdocs/reference/generated/numpy.sum.html#numpy.sum) function to sum all the selected rows (`axis=0`). Note also that row 35 corresponds to the total counts for the whole country for each date. Since we want to calculate the sum ourselves from the provinces data, we have to remove that row first from both `locations` and `nbcases`:
@@ -183,9 +185,10 @@ Let's try and see what the data looks like excluding the first row (data from th
183185
closely:
184186

185187
```{code-cell}
186-
plt.plot(dates, nbcases_ma[1:].T, "--")
187-
plt.xticks(selected_dates, dates[selected_dates])
188-
plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
188+
fig, ax = plt.subplots()
189+
ax.plot(dates, nbcases_ma[1:].T, "--")
190+
ax.set_xticks(selected_dates, dates[selected_dates])
191+
ax.set_title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020")
189192
```
190193

191194
Now that our data has been masked, let's try summing up all the cases in China:
@@ -232,9 +235,10 @@ china_total
232235
We can replace the data with this information and plot a new graph, focusing on Mainland China:
233236

234237
```{code-cell}
235-
plt.plot(dates, china_total.T, "--")
236-
plt.xticks(selected_dates, dates[selected_dates])
237-
plt.title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China")
238+
fig, ax = plt.subplots()
239+
ax.plot(dates, china_total.T, "--")
240+
ax.set_xticks(selected_dates, dates[selected_dates])
241+
ax.set_title("COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China")
238242
```
239243

240244
It's clear that masked arrays are the right solution here. We cannot represent the missing data without mischaracterizing the evolution of the curve.
@@ -271,21 +275,25 @@ package to create a cubic polynomial model that fits the data as best as possibl
271275
```{code-cell}
272276
t = np.arange(len(china_total))
273277
model = np.polynomial.Polynomial.fit(t[~china_total.mask], valid, deg=3)
274-
plt.plot(t, china_total)
275-
plt.plot(t, model(t), "--")
278+
279+
fig, ax = plt.subplots()
280+
ax.plot(t, china_total)
281+
ax.plot(t, model(t), "--")
276282
```
277283

278284
This plot is not so readable since the lines seem to be over each other, so let's summarize in a more elaborate plot. We'll plot the real data when
279285
available, and show the cubic fit for unavailable data, using this fit to compute an estimate to the observed number of cases on January 28th 2020, 7 days after the beginning of the records:
280286

281287
```{code-cell}
282-
plt.plot(t, china_total)
283-
plt.plot(t[china_total.mask], model(t)[china_total.mask], "--", color="orange")
284-
plt.plot(7, model(7), "r*")
285-
plt.xticks([0, 7, 13], dates[[0, 7, 13]])
286-
plt.yticks([0, model(7), 10000, 17500])
287-
plt.legend(["Mainland China", "Cubic estimate", "7 days after start"])
288-
plt.title(
288+
fig, ax = plt.subplots()
289+
ax.plot(t, china_total)
290+
ax.plot(t[china_total.mask], model(t)[china_total.mask], "--", color="orange")
291+
ax.plot(7, model(7), "r*")
292+
293+
ax.set_xticks([0, 7, 13], dates[[0, 7, 13]])
294+
ax.set_yticks([0, model(7), 10000, 17500])
295+
ax.legend(["Mainland China", "Cubic estimate", "7 days after start"])
296+
ax.set_title(
289297
"COVID-19 cumulative cases from Jan 21 to Feb 3 2020 - Mainland China\n"
290298
"Cubic estimate for 7 days after start"
291299
)

0 commit comments

Comments
 (0)