refactor present_model

This commit is contained in:
markus 2021-02-08 13:20:42 +01:00
parent a71a104524
commit 6f4bb006df

View file

@ -398,7 +398,6 @@ def print_qr_info(qr_values: np.ndarray) -> None:
def present_model(model: MCConcentrationModel, bins: int = 200) -> None:
global data
fig, axs = plt.subplots(2, 2, sharex=False, sharey=False)
fig.set_figheight(8)
fig.set_figwidth(10)
@ -414,10 +413,11 @@ def present_model(model: MCConcentrationModel, bins: int = 200) -> None:
axs[x, y].set_yticklabels([])
axs[x, y].set_yticks([])
for data, (x, y) in zip((model.infected._generate_viral_loads(),
model.infected._generate_breathing_rates(),
np.log10(model.infected.emission_rate_when_present())),
((0, 0), (1, 0), (1, 1))):
viral_loads, breathing_rates, qRs = (model.infected._generate_viral_loads(),
model.infected._generate_breathing_rates(),
np.log10(model.infected.emission_rate_when_present()))
for data, (x, y) in zip((viral_loads, breathing_rates, qRs), ((0, 0), (1, 0), (1, 1))):
axs[x, y].hist(data, bins=bins)
top = axs[x, y].get_ylim()[1]
mean, median, std = np.mean(data), np.median(data), np.std(data)
@ -453,23 +453,20 @@ def present_model(model: MCConcentrationModel, bins: int = 200) -> None:
axs[1, 1].set_title('Quantum generation rate')
axs[1, 1].set_xlabel('qR [log10($q\;h^{-1}$)]')
axs[1, 1].annotate('', xy=(mean + std, 2000), xytext=(np.max(data), 2000),
arrowprops={'arrowstyle':'<|-|>','ls':'dashed'})
axs[1, 1].text(mean + std + 0.1, 2100,'Superspreader',fontsize=8)
# TODO: Set height to highest bar
mean, std = np.mean(qRs), np.std(qRs)
axs[1, 1].annotate('', xy=(mean + std, 2000), xytext=(np.max(qRs), 2000),
arrowprops={'arrowstyle': '<|-|>', 'ls': 'dashed'})
axs[1, 1].text(mean + std + 0.1, 2100, 'Superspreader', fontsize=8)
mean_patch = patches.Patch(color='grey',label='Mean')
mean_patch = patches.Patch(color='grey', label='Mean')
median_patch = patches.Patch(color='black', label='Median')
std_patch = patches.Patch(color='lightgrey', linestyle='dashed', label='Standard deviations')
fig.legend(handles=(mean_patch, std_patch, median_patch))
# TODO: call print_qr_info
plt.show()
print(
10**np.median(data),
np.median(10**data),
np.mean(data),
np.mean(10**data), # is this correct?
np.std(data),
np.std(10**data)) # is this correct?
def buaonanno_exposure_model():