diff --git a/examples/emcee_vs_r_inference/.gitignore b/examples/emcee_vs_r_inference/.gitignore new file mode 100644 index 000000000..0ab4fc00d --- /dev/null +++ b/examples/emcee_vs_r_inference/.gitignore @@ -0,0 +1,2 @@ +model_output/ +*.ipynb \ No newline at end of file diff --git a/examples/emcee_vs_r_inference/Makefile b/examples/emcee_vs_r_inference/Makefile new file mode 100644 index 000000000..a44300baa --- /dev/null +++ b/examples/emcee_vs_r_inference/Makefile @@ -0,0 +1,91 @@ +# Cross platform customizations +ifeq ($(OS),Windows_NT) + RM_CMD = del /Q + RM_DIR = rmdir /S /Q +else + RM_CMD = rm -f + RM_DIR = rm -rf +endif + + +# Default target to build all outputs +.PHONY: all +all: \ + model_output/three_state_seir_stacked_hosp_stacked/emcee_inference \ + model_output/three_state_seir_stacked_hosp_stacked/r_inference + + +# Clean all generated files & targets +.PHONY: clean +clean: + $(RM_DIR) model_output + $(RM_CMD) model_input/ground_truth_hospitalizations.csv + $(RM_CMD) three_state_simulate.yml + $(RM_CMD) three_state_emcee_inference.yml + $(RM_CMD) three_state_r_inference.yml + $(RM_CMD) log*.txt + $(RM_CMD) *.h5 + $(RM_CMD) *.pdf + + +# Simulate ground truth data +three_state_simulate.yml: + flepimop patch \ + sim_base.yml \ + --method stochastic \ + --nslots 1 \ + --in-id sim > three_state_simulate.yml + +model_output/three_state_state_varied_Ro_state_varied_incidH/sim: three_state_simulate.yml + flepimop simulate \ + --seir_modifiers_scenarios state_varied_Ro \ + --outcome_modifiers_scenarios state_varied_incidH \ + three_state_simulate.yml + +model_input/ground_truth_hospitalizations.csv: model_output/three_state_state_varied_Ro_state_varied_incidH/sim + python ground_truth_from_simulation.py + + +# Inference with EMCEE +three_state_emcee_inference.yml: + flepimop patch \ + inference_base.yml \ + inference_outcome_modifiers.yml \ + inference_seir_modifiers.yml \ + emcee_inference.yml \ + > three_state_emcee_inference.yml + 
+model_output/three_state_seir_stacked_hosp_stacked/emcee_inference: \ + three_state_emcee_inference.yml model_input/ground_truth_hospitalizations.csv + flepimop-calibrate \ + --config three_state_emcee_inference.yml \ + --jobs 12 \ + --nslots 12 \ + --niterations 500 \ + --nsamples 100 \ + --id emcee_inference + + +# Inference with R +three_state_r_inference.yml: + flepimop patch \ + inference_base.yml \ + inference_outcome_modifiers.yml \ + inference_seir_modifiers.yml \ + r_inference.yml \ + > three_state_r_inference.yml + +model_output/three_state_seir_stacked_hosp_stacked/r_inference: \ + three_state_r_inference.yml model_input/ground_truth_hospitalizations.csv + flepimop-inference-main \ + --config three_state_r_inference.yml \ + --run_id r_inference \ + --seir_modifiers_scenarios seir_stacked \ + --outcome_modifiers_scenarios hosp_stacked \ + --jobs 12 \ + --iterations_per_slot 600 \ + --slots 12 \ + --python "$(CONDA_PREFIX)/bin/python" \ + --rpath "$(CONDA_PREFIX)/bin/R" \ + --save_hosp TRUE \ + --save_seir TRUE diff --git a/examples/emcee_vs_r_inference/README.md b/examples/emcee_vs_r_inference/README.md new file mode 100644 index 000000000..1dc11881a --- /dev/null +++ b/examples/emcee_vs_r_inference/README.md @@ -0,0 +1,3 @@ +# Three-State Inference Comparison + +This example highlights the differences between R and EMCEE inference with a small but semi-realistic three-state model. 
diff --git a/examples/emcee_vs_r_inference/emcee_inference.yml b/examples/emcee_vs_r_inference/emcee_inference.yml new file mode 100644 index 000000000..d441be0ea --- /dev/null +++ b/examples/emcee_vs_r_inference/emcee_inference.yml @@ -0,0 +1,13 @@ +inference: + iterations_per_slot: 250 + method: emcee + do_inference: true + gt_data_path: model_input/ground_truth_hospitalizations.csv + statistics: + incidH: + name: incidH + sim_var: hospitalizations + data_var: incidH + zero_to_one: True + likelihood: + dist: pois \ No newline at end of file diff --git a/examples/emcee_vs_r_inference/ground_truth_from_simulation.py b/examples/emcee_vs_r_inference/ground_truth_from_simulation.py new file mode 100644 index 000000000..24abded33 --- /dev/null +++ b/examples/emcee_vs_r_inference/ground_truth_from_simulation.py @@ -0,0 +1,24 @@ +from pathlib import Path + +from gempyor.utils import read_directory + + +project_path = Path(__file__).parent + +simulation_output = ( + project_path + / "model_output" + / "three_state_state_varied_Ro_state_varied_incidH" + / "sim" +) +if not simulation_output.exists(): + raise FileNotFoundError( + f"Simulation output directory {simulation_output} does not exist. " + "Please run the simulation first." 
+ ) + +hosp = read_directory(simulation_output, filters="hosp") +hosp = hosp[["date", "subpop", "hospitalizations_curr"]] +hosp = hosp.rename(columns={"hospitalizations_curr": "incidH"}) +hosp["incidH"] = hosp["incidH"].astype(int) +hosp.to_csv(project_path / "model_input" / "ground_truth_hospitalizations.csv", index=False) diff --git a/examples/emcee_vs_r_inference/inference_base.yml b/examples/emcee_vs_r_inference/inference_base.yml new file mode 100644 index 000000000..3ffcff401 --- /dev/null +++ b/examples/emcee_vs_r_inference/inference_base.yml @@ -0,0 +1,58 @@ +name: three_state +start_date: 2024-01-01 +end_date: 2024-12-31 + +subpop_setup: + geodata: model_input/geodata.csv + mobility: model_input/mobility.csv + +initial_conditions: + method: SetInitialConditions + initial_conditions_file: model_input/initial_conditions.csv + allow_missing_subpops: TRUE + allow_missing_compartments: TRUE + +compartments: + infection_stage: ["S", "E", "I", "R"] + +seir: + integration: + method: rk4 + dt: 0.25 + parameters: + sigma: + value: 1 + gamma: + value: 1 + Ro: + value: 1 + transitions: + - source: ["S"] + destination: ["E"] + rate: ["Ro * gamma"] + proportional_to: [["S"], ["I"]] + proportion_exponent: ["1", "1"] + - source: ["E"] + destination: ["I"] + rate: ["sigma"] + proportional_to: ["E"] + proportion_exponent: ["1"] + - source: ["I"] + destination: ["R"] + rate: ["gamma"] + proportional_to: ["I"] + proportion_exponent: ["1"] + +outcomes: + method: delayframe + outcomes: + hospitalizations: + source: + incidence: + infection_stage: "I" + probability: + value: 1 + delay: + value: 1 + duration: + value: 1 \ No newline at end of file diff --git a/examples/emcee_vs_r_inference/inference_outcome_modifiers.yml b/examples/emcee_vs_r_inference/inference_outcome_modifiers.yml new file mode 100644 index 000000000..991aa87f9 --- /dev/null +++ b/examples/emcee_vs_r_inference/inference_outcome_modifiers.yml @@ -0,0 +1,78 @@ +outcome_modifiers: + scenarios: + - hosp_stacked + 
modifiers: + hosp_probability: + method: SinglePeriodModifier + parameter: incidH::probability + subpop: "all" + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 0.1 + sd: 0.05 + a: 0 + b: 1 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + hosp_delay: + method: SinglePeriodModifier + parameter: incidH::delay + subpop: "all" + subpop_groups: + - + - NC + - SC + - + - GA + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 7.0 + sd: 2.0 + a: 0 + b: 21.0 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + hosp_duration: + method: SinglePeriodModifier + parameter: incidH::duration + subpop: "all" + subpop_groups: + - + - NC + - SC + - GA + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 14.0 + sd: 7.0 + a: 3.0 + b: 21.0 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + hosp_stacked: + method: StackedModifier + subpop: "all" + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + modifiers: + - hosp_probability + - hosp_delay + - hosp_duration diff --git a/examples/emcee_vs_r_inference/inference_seir_modifiers.yml b/examples/emcee_vs_r_inference/inference_seir_modifiers.yml new file mode 100644 index 000000000..6ced4a15b --- /dev/null +++ b/examples/emcee_vs_r_inference/inference_seir_modifiers.yml @@ -0,0 +1,66 @@ +seir_modifiers: + scenarios: + - seir_stacked + modifiers: + sigma_fit: + method: SinglePeriodModifier + parameter: sigma + subpop: "all" + subpop_groups: ["NC", "SC", "GA"] + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 0.5 + sd: 0.2 + a: 0 + b: 3 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + gamma_fit: + method: SinglePeriodModifier + parameter: gamma + subpop: "all" + subpop_groups: ["NC", 
"SC", "GA"] + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 0.5 + sd: 0.2 + a: 0 + b: 3 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + Ro_fit: + method: SinglePeriodModifier + parameter: Ro + subpop: "all" + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 2.5 + sd: 0.1 + a: 0 + b: 5 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + seir_stacked: + method: StackedModifier + modifiers: + - sigma_fit + - gamma_fit + - Ro_fit diff --git a/examples/emcee_vs_r_inference/model_input/geodata.csv b/examples/emcee_vs_r_inference/model_input/geodata.csv new file mode 100644 index 000000000..d1e627c1f --- /dev/null +++ b/examples/emcee_vs_r_inference/model_input/geodata.csv @@ -0,0 +1,4 @@ +"subpop","population" +"NC","10264876" +"SC","5020806" +"GA","10403847" \ No newline at end of file diff --git a/examples/emcee_vs_r_inference/model_input/ground_truth_hospitalizations.csv b/examples/emcee_vs_r_inference/model_input/ground_truth_hospitalizations.csv new file mode 100644 index 000000000..af2c45e1b --- /dev/null +++ b/examples/emcee_vs_r_inference/model_input/ground_truth_hospitalizations.csv @@ -0,0 +1,1099 @@ +date,subpop,incidH +2024-01-01,NC,0 +2024-01-02,NC,1 +2024-01-03,NC,1 +2024-01-04,NC,5 +2024-01-05,NC,10 +2024-01-06,NC,20 +2024-01-07,NC,33 +2024-01-08,NC,54 +2024-01-09,NC,78 +2024-01-10,NC,113 +2024-01-11,NC,141 +2024-01-12,NC,177 +2024-01-13,NC,219 +2024-01-14,NC,274 +2024-01-15,NC,341 +2024-01-16,NC,404 +2024-01-17,NC,489 +2024-01-18,NC,601 +2024-01-19,NC,728 +2024-01-20,NC,856 +2024-01-21,NC,1037 +2024-01-22,NC,1239 +2024-01-23,NC,1507 +2024-01-24,NC,1756 +2024-01-25,NC,2087 +2024-01-26,NC,2439 +2024-01-27,NC,2893 +2024-01-28,NC,3427 +2024-01-29,NC,4043 +2024-01-30,NC,4758 +2024-01-31,NC,5523 +2024-02-01,NC,6559 +2024-02-02,NC,7638 +2024-02-03,NC,8974 +2024-02-04,NC,10453 
+2024-02-05,NC,12293 +2024-02-06,NC,14216 +2024-02-07,NC,16522 +2024-02-08,NC,19127 +2024-02-09,NC,22114 +2024-02-10,NC,25833 +2024-02-11,NC,29773 +2024-02-12,NC,34426 +2024-02-13,NC,39579 +2024-02-14,NC,45644 +2024-02-15,NC,52627 +2024-02-16,NC,60665 +2024-02-17,NC,69619 +2024-02-18,NC,79756 +2024-02-19,NC,91528 +2024-02-20,NC,104736 +2024-02-21,NC,119599 +2024-02-22,NC,136457 +2024-02-23,NC,155502 +2024-02-24,NC,177165 +2024-02-25,NC,200811 +2024-02-26,NC,227739 +2024-02-27,NC,257525 +2024-02-28,NC,291062 +2024-02-29,NC,328321 +2024-03-01,NC,369836 +2024-03-02,NC,415665 +2024-03-03,NC,465737 +2024-03-04,NC,520785 +2024-03-05,NC,580322 +2024-03-06,NC,646413 +2024-03-07,NC,718444 +2024-03-08,NC,796757 +2024-03-09,NC,881076 +2024-03-10,NC,970807 +2024-03-11,NC,1066535 +2024-03-12,NC,1167926 +2024-03-13,NC,1274436 +2024-03-14,NC,1385522 +2024-03-15,NC,1501044 +2024-03-16,NC,1618270 +2024-03-17,NC,1736892 +2024-03-18,NC,1854146 +2024-03-19,NC,1968747 +2024-03-20,NC,2080723 +2024-03-21,NC,2186123 +2024-03-22,NC,2283897 +2024-03-23,NC,2372771 +2024-03-24,NC,2449498 +2024-03-25,NC,2512763 +2024-03-26,NC,2561626 +2024-03-27,NC,2593860 +2024-03-28,NC,2610003 +2024-03-29,NC,2611082 +2024-03-30,NC,2594540 +2024-03-31,NC,2561346 +2024-04-01,NC,2512150 +2024-04-02,NC,2450092 +2024-04-03,NC,2375535 +2024-04-04,NC,2288927 +2024-04-05,NC,2193660 +2024-04-06,NC,2091645 +2024-04-07,NC,1985426 +2024-04-08,NC,1874535 +2024-04-09,NC,1761613 +2024-04-10,NC,1649737 +2024-04-11,NC,1539361 +2024-04-12,NC,1430215 +2024-04-13,NC,1324746 +2024-04-14,NC,1223864 +2024-04-15,NC,1127211 +2024-04-16,NC,1035453 +2024-04-17,NC,949073 +2024-04-18,NC,867893 +2024-04-19,NC,792554 +2024-04-20,NC,722374 +2024-04-21,NC,657759 +2024-04-22,NC,597512 +2024-04-23,NC,542041 +2024-04-24,NC,491544 +2024-04-25,NC,445787 +2024-04-26,NC,403659 +2024-04-27,NC,364977 +2024-04-28,NC,330014 +2024-04-29,NC,298190 +2024-04-30,NC,269341 +2024-05-01,NC,243169 +2024-05-02,NC,219359 +2024-05-03,NC,197969 
+2024-05-04,NC,178699 +2024-05-05,NC,160886 +2024-05-06,NC,145141 +2024-05-07,NC,130957 +2024-05-08,NC,118022 +2024-05-09,NC,106527 +2024-05-10,NC,95920 +2024-05-11,NC,86429 +2024-05-12,NC,78123 +2024-05-13,NC,70524 +2024-05-14,NC,63723 +2024-05-15,NC,57518 +2024-05-16,NC,51906 +2024-05-17,NC,46682 +2024-05-18,NC,42118 +2024-05-19,NC,37889 +2024-05-20,NC,34088 +2024-05-21,NC,30730 +2024-05-22,NC,27621 +2024-05-23,NC,24926 +2024-05-24,NC,22396 +2024-05-25,NC,20240 +2024-05-26,NC,18291 +2024-05-27,NC,16573 +2024-05-28,NC,14938 +2024-05-29,NC,13538 +2024-05-30,NC,12229 +2024-05-31,NC,10986 +2024-06-01,NC,9909 +2024-06-02,NC,8919 +2024-06-03,NC,8065 +2024-06-04,NC,7268 +2024-06-05,NC,6516 +2024-06-06,NC,5867 +2024-06-07,NC,5267 +2024-06-08,NC,4690 +2024-06-09,NC,4235 +2024-06-10,NC,3827 +2024-06-11,NC,3429 +2024-06-12,NC,3068 +2024-06-13,NC,2737 +2024-06-14,NC,2468 +2024-06-15,NC,2241 +2024-06-16,NC,1989 +2024-06-17,NC,1824 +2024-06-18,NC,1642 +2024-06-19,NC,1458 +2024-06-20,NC,1331 +2024-06-21,NC,1177 +2024-06-22,NC,1063 +2024-06-23,NC,958 +2024-06-24,NC,857 +2024-06-25,NC,763 +2024-06-26,NC,689 +2024-06-27,NC,617 +2024-06-28,NC,568 +2024-06-29,NC,505 +2024-06-30,NC,445 +2024-07-01,NC,417 +2024-07-02,NC,377 +2024-07-03,NC,343 +2024-07-04,NC,303 +2024-07-05,NC,274 +2024-07-06,NC,245 +2024-07-07,NC,214 +2024-07-08,NC,191 +2024-07-09,NC,175 +2024-07-10,NC,159 +2024-07-11,NC,146 +2024-07-12,NC,136 +2024-07-13,NC,124 +2024-07-14,NC,114 +2024-07-15,NC,99 +2024-07-16,NC,96 +2024-07-17,NC,97 +2024-07-18,NC,82 +2024-07-19,NC,77 +2024-07-20,NC,71 +2024-07-21,NC,64 +2024-07-22,NC,59 +2024-07-23,NC,52 +2024-07-24,NC,47 +2024-07-25,NC,43 +2024-07-26,NC,36 +2024-07-27,NC,26 +2024-07-28,NC,23 +2024-07-29,NC,18 +2024-07-30,NC,15 +2024-07-31,NC,13 +2024-08-01,NC,12 +2024-08-02,NC,11 +2024-08-03,NC,10 +2024-08-04,NC,10 +2024-08-05,NC,9 +2024-08-06,NC,10 +2024-08-07,NC,8 +2024-08-08,NC,10 +2024-08-09,NC,9 +2024-08-10,NC,8 +2024-08-11,NC,7 +2024-08-12,NC,5 +2024-08-13,NC,5 
+2024-08-14,NC,6 +2024-08-15,NC,5 +2024-08-16,NC,5 +2024-08-17,NC,4 +2024-08-18,NC,2 +2024-08-19,NC,2 +2024-08-20,NC,3 +2024-08-21,NC,3 +2024-08-22,NC,3 +2024-08-23,NC,3 +2024-08-24,NC,2 +2024-08-25,NC,3 +2024-08-26,NC,2 +2024-08-27,NC,2 +2024-08-28,NC,2 +2024-08-29,NC,2 +2024-08-30,NC,1 +2024-08-31,NC,1 +2024-09-01,NC,1 +2024-09-02,NC,1 +2024-09-03,NC,1 +2024-09-04,NC,0 +2024-09-05,NC,0 +2024-09-06,NC,0 +2024-09-07,NC,0 +2024-09-08,NC,1 +2024-09-09,NC,1 +2024-09-10,NC,1 +2024-09-11,NC,1 +2024-09-12,NC,1 +2024-09-13,NC,1 +2024-09-14,NC,1 +2024-09-15,NC,1 +2024-09-16,NC,1 +2024-09-17,NC,1 +2024-09-18,NC,1 +2024-09-19,NC,1 +2024-09-20,NC,1 +2024-09-21,NC,1 +2024-09-22,NC,1 +2024-09-23,NC,1 +2024-09-24,NC,1 +2024-09-25,NC,1 +2024-09-26,NC,2 +2024-09-27,NC,2 +2024-09-28,NC,1 +2024-09-29,NC,1 +2024-09-30,NC,1 +2024-10-01,NC,1 +2024-10-02,NC,2 +2024-10-03,NC,2 +2024-10-04,NC,2 +2024-10-05,NC,2 +2024-10-06,NC,2 +2024-10-07,NC,2 +2024-10-08,NC,2 +2024-10-09,NC,2 +2024-10-10,NC,2 +2024-10-11,NC,2 +2024-10-12,NC,1 +2024-10-13,NC,1 +2024-10-14,NC,1 +2024-10-15,NC,1 +2024-10-16,NC,0 +2024-10-17,NC,0 +2024-10-18,NC,0 +2024-10-19,NC,0 +2024-10-20,NC,0 +2024-10-21,NC,0 +2024-10-22,NC,0 +2024-10-23,NC,0 +2024-10-24,NC,0 +2024-10-25,NC,1 +2024-10-26,NC,1 +2024-10-27,NC,1 +2024-10-28,NC,1 +2024-10-29,NC,1 +2024-10-30,NC,1 +2024-10-31,NC,1 +2024-11-01,NC,1 +2024-11-02,NC,1 +2024-11-03,NC,1 +2024-11-04,NC,0 +2024-11-05,NC,0 +2024-11-06,NC,0 +2024-11-07,NC,0 +2024-11-08,NC,0 +2024-11-09,NC,0 +2024-11-10,NC,0 +2024-11-11,NC,0 +2024-11-12,NC,0 +2024-11-13,NC,0 +2024-11-14,NC,0 +2024-11-15,NC,0 +2024-11-16,NC,1 +2024-11-17,NC,1 +2024-11-18,NC,1 +2024-11-19,NC,1 +2024-11-20,NC,1 +2024-11-21,NC,1 +2024-11-22,NC,1 +2024-11-23,NC,1 +2024-11-24,NC,1 +2024-11-25,NC,1 +2024-11-26,NC,0 +2024-11-27,NC,0 +2024-11-28,NC,0 +2024-11-29,NC,0 +2024-11-30,NC,0 +2024-12-01,NC,0 +2024-12-02,NC,0 +2024-12-03,NC,0 +2024-12-04,NC,0 +2024-12-05,NC,0 +2024-12-06,NC,0 +2024-12-07,NC,1 +2024-12-08,NC,2 
+2024-12-09,NC,2 +2024-12-10,NC,2 +2024-12-11,NC,2 +2024-12-12,NC,2 +2024-12-13,NC,2 +2024-12-14,NC,2 +2024-12-15,NC,3 +2024-12-16,NC,3 +2024-12-17,NC,2 +2024-12-18,NC,1 +2024-12-19,NC,1 +2024-12-20,NC,1 +2024-12-21,NC,1 +2024-12-22,NC,1 +2024-12-23,NC,1 +2024-12-24,NC,1 +2024-12-25,NC,0 +2024-12-26,NC,0 +2024-12-27,NC,0 +2024-12-28,NC,0 +2024-12-29,NC,0 +2024-12-30,NC,0 +2024-12-31,NC,0 +2024-01-01,SC,0 +2024-01-02,SC,1137 +2024-01-03,SC,2159 +2024-01-04,SC,3290 +2024-01-05,SC,4454 +2024-01-06,SC,5751 +2024-01-07,SC,7088 +2024-01-08,SC,8720 +2024-01-09,SC,10542 +2024-01-10,SC,12575 +2024-01-11,SC,14954 +2024-01-12,SC,16343 +2024-01-13,SC,18271 +2024-01-14,SC,20523 +2024-01-15,SC,23128 +2024-01-16,SC,26153 +2024-01-17,SC,29680 +2024-01-18,SC,33466 +2024-01-19,SC,37877 +2024-01-20,SC,42874 +2024-01-21,SC,48441 +2024-01-22,SC,54961 +2024-01-23,SC,62142 +2024-01-24,SC,70118 +2024-01-25,SC,79157 +2024-01-26,SC,89195 +2024-01-27,SC,100620 +2024-01-28,SC,113407 +2024-01-29,SC,127517 +2024-01-30,SC,143452 +2024-01-31,SC,160864 +2024-02-01,SC,180493 +2024-02-02,SC,202310 +2024-02-03,SC,226323 +2024-02-04,SC,252631 +2024-02-05,SC,281884 +2024-02-06,SC,313839 +2024-02-07,SC,349186 +2024-02-08,SC,387345 +2024-02-09,SC,428402 +2024-02-10,SC,473345 +2024-02-11,SC,520829 +2024-02-12,SC,571545 +2024-02-13,SC,625326 +2024-02-14,SC,681575 +2024-02-15,SC,740331 +2024-02-16,SC,800139 +2024-02-17,SC,859679 +2024-02-18,SC,920580 +2024-02-19,SC,980852 +2024-02-20,SC,1038640 +2024-02-21,SC,1094092 +2024-02-22,SC,1144829 +2024-02-23,SC,1190635 +2024-02-24,SC,1230710 +2024-02-25,SC,1262708 +2024-02-26,SC,1288197 +2024-02-27,SC,1305379 +2024-02-28,SC,1313002 +2024-02-29,SC,1311892 +2024-03-01,SC,1302721 +2024-03-02,SC,1284095 +2024-03-03,SC,1258045 +2024-03-04,SC,1225202 +2024-03-05,SC,1185214 +2024-03-06,SC,1140422 +2024-03-07,SC,1089865 +2024-03-08,SC,1036987 +2024-03-09,SC,981738 +2024-03-10,SC,924258 +2024-03-11,SC,865999 +2024-03-12,SC,808588 +2024-03-13,SC,752124 +2024-03-14,SC,696470 
+2024-03-15,SC,643282 +2024-03-16,SC,592564 +2024-03-17,SC,544577 +2024-03-18,SC,499062 +2024-03-19,SC,456469 +2024-03-20,SC,416923 +2024-03-21,SC,379931 +2024-03-22,SC,346013 +2024-03-23,SC,314610 +2024-03-24,SC,285810 +2024-03-25,SC,259657 +2024-03-26,SC,235257 +2024-03-27,SC,213403 +2024-03-28,SC,193265 +2024-03-29,SC,174827 +2024-03-30,SC,158735 +2024-03-31,SC,143984 +2024-04-01,SC,130377 +2024-04-02,SC,118096 +2024-04-03,SC,106993 +2024-04-04,SC,97025 +2024-04-05,SC,88097 +2024-04-06,SC,79926 +2024-04-07,SC,72702 +2024-04-08,SC,66209 +2024-04-09,SC,59926 +2024-04-10,SC,54487 +2024-04-11,SC,49561 +2024-04-12,SC,45142 +2024-04-13,SC,41217 +2024-04-14,SC,37398 +2024-04-15,SC,34078 +2024-04-16,SC,31045 +2024-04-17,SC,28317 +2024-04-18,SC,25713 +2024-04-19,SC,23401 +2024-04-20,SC,21273 +2024-04-21,SC,19382 +2024-04-22,SC,17714 +2024-04-23,SC,16102 +2024-04-24,SC,14698 +2024-04-25,SC,13336 +2024-04-26,SC,12159 +2024-04-27,SC,11092 +2024-04-28,SC,10099 +2024-04-29,SC,9220 +2024-04-30,SC,8410 +2024-05-01,SC,7669 +2024-05-02,SC,6972 +2024-05-03,SC,6377 +2024-05-04,SC,5825 +2024-05-05,SC,5290 +2024-05-06,SC,4847 +2024-05-07,SC,4393 +2024-05-08,SC,3977 +2024-05-09,SC,3617 +2024-05-10,SC,3276 +2024-05-11,SC,2960 +2024-05-12,SC,2707 +2024-05-13,SC,2453 +2024-05-14,SC,2212 +2024-05-15,SC,2029 +2024-05-16,SC,1828 +2024-05-17,SC,1653 +2024-05-18,SC,1491 +2024-05-19,SC,1340 +2024-05-20,SC,1226 +2024-05-21,SC,1130 +2024-05-22,SC,1024 +2024-05-23,SC,914 +2024-05-24,SC,837 +2024-05-25,SC,750 +2024-05-26,SC,679 +2024-05-27,SC,621 +2024-05-28,SC,567 +2024-05-29,SC,511 +2024-05-30,SC,455 +2024-05-31,SC,399 +2024-06-01,SC,359 +2024-06-02,SC,317 +2024-06-03,SC,286 +2024-06-04,SC,266 +2024-06-05,SC,236 +2024-06-06,SC,213 +2024-06-07,SC,198 +2024-06-08,SC,182 +2024-06-09,SC,165 +2024-06-10,SC,151 +2024-06-11,SC,143 +2024-06-12,SC,128 +2024-06-13,SC,121 +2024-06-14,SC,106 +2024-06-15,SC,99 +2024-06-16,SC,89 +2024-06-17,SC,83 +2024-06-18,SC,76 +2024-06-19,SC,69 +2024-06-20,SC,64 
+2024-06-21,SC,57 +2024-06-22,SC,55 +2024-06-23,SC,44 +2024-06-24,SC,39 +2024-06-25,SC,29 +2024-06-26,SC,29 +2024-06-27,SC,28 +2024-06-28,SC,28 +2024-06-29,SC,26 +2024-06-30,SC,22 +2024-07-01,SC,18 +2024-07-02,SC,15 +2024-07-03,SC,14 +2024-07-04,SC,12 +2024-07-05,SC,14 +2024-07-06,SC,10 +2024-07-07,SC,9 +2024-07-08,SC,9 +2024-07-09,SC,7 +2024-07-10,SC,7 +2024-07-11,SC,7 +2024-07-12,SC,6 +2024-07-13,SC,7 +2024-07-14,SC,7 +2024-07-15,SC,5 +2024-07-16,SC,5 +2024-07-17,SC,4 +2024-07-18,SC,2 +2024-07-19,SC,2 +2024-07-20,SC,2 +2024-07-21,SC,1 +2024-07-22,SC,1 +2024-07-23,SC,0 +2024-07-24,SC,0 +2024-07-25,SC,0 +2024-07-26,SC,0 +2024-07-27,SC,0 +2024-07-28,SC,0 +2024-07-29,SC,0 +2024-07-30,SC,0 +2024-07-31,SC,0 +2024-08-01,SC,0 +2024-08-02,SC,1 +2024-08-03,SC,1 +2024-08-04,SC,1 +2024-08-05,SC,1 +2024-08-06,SC,1 +2024-08-07,SC,1 +2024-08-08,SC,2 +2024-08-09,SC,2 +2024-08-10,SC,2 +2024-08-11,SC,2 +2024-08-12,SC,1 +2024-08-13,SC,1 +2024-08-14,SC,1 +2024-08-15,SC,1 +2024-08-16,SC,2 +2024-08-17,SC,2 +2024-08-18,SC,2 +2024-08-19,SC,2 +2024-08-20,SC,2 +2024-08-21,SC,2 +2024-08-22,SC,2 +2024-08-23,SC,2 +2024-08-24,SC,2 +2024-08-25,SC,2 +2024-08-26,SC,1 +2024-08-27,SC,1 +2024-08-28,SC,0 +2024-08-29,SC,0 +2024-08-30,SC,0 +2024-08-31,SC,0 +2024-09-01,SC,0 +2024-09-02,SC,0 +2024-09-03,SC,0 +2024-09-04,SC,0 +2024-09-05,SC,0 +2024-09-06,SC,0 +2024-09-07,SC,0 +2024-09-08,SC,0 +2024-09-09,SC,0 +2024-09-10,SC,0 +2024-09-11,SC,0 +2024-09-12,SC,0 +2024-09-13,SC,0 +2024-09-14,SC,0 +2024-09-15,SC,0 +2024-09-16,SC,1 +2024-09-17,SC,1 +2024-09-18,SC,1 +2024-09-19,SC,1 +2024-09-20,SC,1 +2024-09-21,SC,1 +2024-09-22,SC,1 +2024-09-23,SC,1 +2024-09-24,SC,1 +2024-09-25,SC,1 +2024-09-26,SC,0 +2024-09-27,SC,0 +2024-09-28,SC,0 +2024-09-29,SC,0 +2024-09-30,SC,0 +2024-10-01,SC,0 +2024-10-02,SC,0 +2024-10-03,SC,0 +2024-10-04,SC,0 +2024-10-05,SC,0 +2024-10-06,SC,0 +2024-10-07,SC,0 +2024-10-08,SC,1 +2024-10-09,SC,1 +2024-10-10,SC,1 +2024-10-11,SC,1 +2024-10-12,SC,1 +2024-10-13,SC,1 +2024-10-14,SC,1 
+2024-10-15,SC,1 +2024-10-16,SC,1 +2024-10-17,SC,1 +2024-10-18,SC,0 +2024-10-19,SC,0 +2024-10-20,SC,0 +2024-10-21,SC,0 +2024-10-22,SC,0 +2024-10-23,SC,0 +2024-10-24,SC,0 +2024-10-25,SC,0 +2024-10-26,SC,0 +2024-10-27,SC,0 +2024-10-28,SC,0 +2024-10-29,SC,0 +2024-10-30,SC,0 +2024-10-31,SC,0 +2024-11-01,SC,0 +2024-11-02,SC,0 +2024-11-03,SC,1 +2024-11-04,SC,1 +2024-11-05,SC,1 +2024-11-06,SC,1 +2024-11-07,SC,1 +2024-11-08,SC,1 +2024-11-09,SC,1 +2024-11-10,SC,1 +2024-11-11,SC,1 +2024-11-12,SC,1 +2024-11-13,SC,0 +2024-11-14,SC,0 +2024-11-15,SC,0 +2024-11-16,SC,0 +2024-11-17,SC,0 +2024-11-18,SC,0 +2024-11-19,SC,0 +2024-11-20,SC,0 +2024-11-21,SC,0 +2024-11-22,SC,0 +2024-11-23,SC,0 +2024-11-24,SC,0 +2024-11-25,SC,0 +2024-11-26,SC,0 +2024-11-27,SC,0 +2024-11-28,SC,0 +2024-11-29,SC,0 +2024-11-30,SC,0 +2024-12-01,SC,0 +2024-12-02,SC,0 +2024-12-03,SC,0 +2024-12-04,SC,0 +2024-12-05,SC,0 +2024-12-06,SC,0 +2024-12-07,SC,0 +2024-12-08,SC,0 +2024-12-09,SC,0 +2024-12-10,SC,0 +2024-12-11,SC,0 +2024-12-12,SC,0 +2024-12-13,SC,0 +2024-12-14,SC,0 +2024-12-15,SC,0 +2024-12-16,SC,0 +2024-12-17,SC,0 +2024-12-18,SC,0 +2024-12-19,SC,0 +2024-12-20,SC,0 +2024-12-21,SC,1 +2024-12-22,SC,1 +2024-12-23,SC,1 +2024-12-24,SC,1 +2024-12-25,SC,1 +2024-12-26,SC,1 +2024-12-27,SC,1 +2024-12-28,SC,1 +2024-12-29,SC,1 +2024-12-30,SC,1 +2024-12-31,SC,0 +2024-01-01,GA,0 +2024-01-02,GA,0 +2024-01-03,GA,3 +2024-01-04,GA,5 +2024-01-05,GA,8 +2024-01-06,GA,11 +2024-01-07,GA,20 +2024-01-08,GA,29 +2024-01-09,GA,43 +2024-01-10,GA,49 +2024-01-11,GA,68 +2024-01-12,GA,89 +2024-01-13,GA,115 +2024-01-14,GA,145 +2024-01-15,GA,184 +2024-01-16,GA,224 +2024-01-17,GA,267 +2024-01-18,GA,330 +2024-01-19,GA,408 +2024-01-20,GA,509 +2024-01-21,GA,620 +2024-01-22,GA,769 +2024-01-23,GA,931 +2024-01-24,GA,1099 +2024-01-25,GA,1315 +2024-01-26,GA,1576 +2024-01-27,GA,1867 +2024-01-28,GA,2199 +2024-01-29,GA,2595 +2024-01-30,GA,3066 +2024-01-31,GA,3608 +2024-02-01,GA,4239 +2024-02-02,GA,4974 +2024-02-03,GA,5853 +2024-02-04,GA,6879 
+2024-02-05,GA,8009 +2024-02-06,GA,9321 +2024-02-07,GA,10924 +2024-02-08,GA,12836 +2024-02-09,GA,14981 +2024-02-10,GA,17525 +2024-02-11,GA,20387 +2024-02-12,GA,23780 +2024-02-13,GA,27741 +2024-02-14,GA,32185 +2024-02-15,GA,37405 +2024-02-16,GA,43478 +2024-02-17,GA,50370 +2024-02-18,GA,58102 +2024-02-19,GA,67389 +2024-02-20,GA,77431 +2024-02-21,GA,89472 +2024-02-22,GA,103255 +2024-02-23,GA,118682 +2024-02-24,GA,136485 +2024-02-25,GA,156621 +2024-02-26,GA,179277 +2024-02-27,GA,204997 +2024-02-28,GA,234385 +2024-02-29,GA,267378 +2024-03-01,GA,305064 +2024-03-02,GA,346750 +2024-03-03,GA,393650 +2024-03-04,GA,446741 +2024-03-05,GA,505536 +2024-03-06,GA,571107 +2024-03-07,GA,644372 +2024-03-08,GA,724948 +2024-03-09,GA,813454 +2024-03-10,GA,909599 +2024-03-11,GA,1015100 +2024-03-12,GA,1128258 +2024-03-13,GA,1249639 +2024-03-14,GA,1378915 +2024-03-15,GA,1515357 +2024-03-16,GA,1657041 +2024-03-17,GA,1802376 +2024-03-18,GA,1949604 +2024-03-19,GA,2097597 +2024-03-20,GA,2243454 +2024-03-21,GA,2382816 +2024-03-22,GA,2513930 +2024-03-23,GA,2633964 +2024-03-24,GA,2738694 +2024-03-25,GA,2827986 +2024-03-26,GA,2897011 +2024-03-27,GA,2945790 +2024-03-28,GA,2973434 +2024-03-29,GA,2977028 +2024-03-30,GA,2957672 +2024-03-31,GA,2916945 +2024-04-01,GA,2855338 +2024-04-02,GA,2775205 +2024-04-03,GA,2678544 +2024-04-04,GA,2566304 +2024-04-05,GA,2444085 +2024-04-06,GA,2312538 +2024-04-07,GA,2175701 +2024-04-08,GA,2036278 +2024-04-09,GA,1895342 +2024-04-10,GA,1755484 +2024-04-11,GA,1619931 +2024-04-12,GA,1487703 +2024-04-13,GA,1361731 +2024-04-14,GA,1242326 +2024-04-15,GA,1130683 +2024-04-16,GA,1026325 +2024-04-17,GA,928783 +2024-04-18,GA,838126 +2024-04-19,GA,756354 +2024-04-20,GA,680845 +2024-04-21,GA,611821 +2024-04-22,GA,549572 +2024-04-23,GA,493008 +2024-04-24,GA,441879 +2024-04-25,GA,395548 +2024-04-26,GA,353990 +2024-04-27,GA,316438 +2024-04-28,GA,283068 +2024-04-29,GA,252582 +2024-04-30,GA,225984 +2024-05-01,GA,202046 +2024-05-02,GA,180484 +2024-05-03,GA,161354 +2024-05-04,GA,144328 
+2024-05-05,GA,128718 +2024-05-06,GA,114963 +2024-05-07,GA,102684 +2024-05-08,GA,91645 +2024-05-09,GA,81811 +2024-05-10,GA,72820 +2024-05-11,GA,64928 +2024-05-12,GA,57998 +2024-05-13,GA,51750 +2024-05-14,GA,46049 +2024-05-15,GA,41140 +2024-05-16,GA,36757 +2024-05-17,GA,32730 +2024-05-18,GA,29337 +2024-05-19,GA,26178 +2024-05-20,GA,23359 +2024-05-21,GA,20935 +2024-05-22,GA,18688 +2024-05-23,GA,16625 +2024-05-24,GA,14942 +2024-05-25,GA,13395 +2024-05-26,GA,11913 +2024-05-27,GA,10694 +2024-05-28,GA,9520 +2024-05-29,GA,8557 +2024-05-30,GA,7637 +2024-05-31,GA,6803 +2024-06-01,GA,6108 +2024-06-02,GA,5514 +2024-06-03,GA,4905 +2024-06-04,GA,4338 +2024-06-05,GA,3878 +2024-06-06,GA,3483 +2024-06-07,GA,3119 +2024-06-08,GA,2766 +2024-06-09,GA,2528 +2024-06-10,GA,2279 +2024-06-11,GA,2029 +2024-06-12,GA,1782 +2024-06-13,GA,1606 +2024-06-14,GA,1432 +2024-06-15,GA,1280 +2024-06-16,GA,1137 +2024-06-17,GA,996 +2024-06-18,GA,894 +2024-06-19,GA,792 +2024-06-20,GA,686 +2024-06-21,GA,600 +2024-06-22,GA,541 +2024-06-23,GA,474 +2024-06-24,GA,423 +2024-06-25,GA,376 +2024-06-26,GA,328 +2024-06-27,GA,294 +2024-06-28,GA,269 +2024-06-29,GA,238 +2024-06-30,GA,212 +2024-07-01,GA,184 +2024-07-02,GA,164 +2024-07-03,GA,150 +2024-07-04,GA,140 +2024-07-05,GA,124 +2024-07-06,GA,106 +2024-07-07,GA,90 +2024-07-08,GA,78 +2024-07-09,GA,73 +2024-07-10,GA,69 +2024-07-11,GA,61 +2024-07-12,GA,61 +2024-07-13,GA,57 +2024-07-14,GA,49 +2024-07-15,GA,48 +2024-07-16,GA,45 +2024-07-17,GA,46 +2024-07-18,GA,41 +2024-07-19,GA,33 +2024-07-20,GA,29 +2024-07-21,GA,30 +2024-07-22,GA,25 +2024-07-23,GA,19 +2024-07-24,GA,18 +2024-07-25,GA,15 +2024-07-26,GA,11 +2024-07-27,GA,9 +2024-07-28,GA,7 +2024-07-29,GA,6 +2024-07-30,GA,5 +2024-07-31,GA,3 +2024-08-01,GA,5 +2024-08-02,GA,6 +2024-08-03,GA,6 +2024-08-04,GA,5 +2024-08-05,GA,5 +2024-08-06,GA,4 +2024-08-07,GA,4 +2024-08-08,GA,4 +2024-08-09,GA,4 +2024-08-10,GA,4 +2024-08-11,GA,1 +2024-08-12,GA,0 +2024-08-13,GA,0 +2024-08-14,GA,0 +2024-08-15,GA,1 +2024-08-16,GA,1 +2024-08-17,GA,1 
+2024-08-18,GA,1 +2024-08-19,GA,1 +2024-08-20,GA,1 +2024-08-21,GA,2 +2024-08-22,GA,2 +2024-08-23,GA,2 +2024-08-24,GA,2 +2024-08-25,GA,1 +2024-08-26,GA,1 +2024-08-27,GA,1 +2024-08-28,GA,1 +2024-08-29,GA,1 +2024-08-30,GA,1 +2024-08-31,GA,0 +2024-09-01,GA,0 +2024-09-02,GA,0 +2024-09-03,GA,0 +2024-09-04,GA,0 +2024-09-05,GA,0 +2024-09-06,GA,0 +2024-09-07,GA,0 +2024-09-08,GA,0 +2024-09-09,GA,0 +2024-09-10,GA,0 +2024-09-11,GA,0 +2024-09-12,GA,0 +2024-09-13,GA,0 +2024-09-14,GA,0 +2024-09-15,GA,0 +2024-09-16,GA,0 +2024-09-17,GA,0 +2024-09-18,GA,0 +2024-09-19,GA,0 +2024-09-20,GA,0 +2024-09-21,GA,0 +2024-09-22,GA,0 +2024-09-23,GA,0 +2024-09-24,GA,0 +2024-09-25,GA,0 +2024-09-26,GA,0 +2024-09-27,GA,0 +2024-09-28,GA,0 +2024-09-29,GA,0 +2024-09-30,GA,0 +2024-10-01,GA,0 +2024-10-02,GA,0 +2024-10-03,GA,0 +2024-10-04,GA,0 +2024-10-05,GA,0 +2024-10-06,GA,0 +2024-10-07,GA,0 +2024-10-08,GA,0 +2024-10-09,GA,0 +2024-10-10,GA,0 +2024-10-11,GA,0 +2024-10-12,GA,0 +2024-10-13,GA,0 +2024-10-14,GA,0 +2024-10-15,GA,2 +2024-10-16,GA,2 +2024-10-17,GA,2 +2024-10-18,GA,2 +2024-10-19,GA,2 +2024-10-20,GA,2 +2024-10-21,GA,2 +2024-10-22,GA,2 +2024-10-23,GA,2 +2024-10-24,GA,2 +2024-10-25,GA,0 +2024-10-26,GA,0 +2024-10-27,GA,0 +2024-10-28,GA,0 +2024-10-29,GA,0 +2024-10-30,GA,0 +2024-10-31,GA,0 +2024-11-01,GA,0 +2024-11-02,GA,0 +2024-11-03,GA,0 +2024-11-04,GA,0 +2024-11-05,GA,0 +2024-11-06,GA,0 +2024-11-07,GA,0 +2024-11-08,GA,0 +2024-11-09,GA,1 +2024-11-10,GA,1 +2024-11-11,GA,1 +2024-11-12,GA,1 +2024-11-13,GA,1 +2024-11-14,GA,1 +2024-11-15,GA,1 +2024-11-16,GA,1 +2024-11-17,GA,1 +2024-11-18,GA,1 +2024-11-19,GA,0 +2024-11-20,GA,0 +2024-11-21,GA,1 +2024-11-22,GA,1 +2024-11-23,GA,1 +2024-11-24,GA,1 +2024-11-25,GA,1 +2024-11-26,GA,1 +2024-11-27,GA,1 +2024-11-28,GA,1 +2024-11-29,GA,1 +2024-11-30,GA,1 +2024-12-01,GA,0 +2024-12-02,GA,0 +2024-12-03,GA,0 +2024-12-04,GA,0 +2024-12-05,GA,0 +2024-12-06,GA,0 +2024-12-07,GA,0 +2024-12-08,GA,0 +2024-12-09,GA,0 +2024-12-10,GA,0 +2024-12-11,GA,0 +2024-12-12,GA,0 
+2024-12-13,GA,0 +2024-12-14,GA,0 +2024-12-15,GA,0 +2024-12-16,GA,0 +2024-12-17,GA,0 +2024-12-18,GA,0 +2024-12-19,GA,0 +2024-12-20,GA,0 +2024-12-21,GA,0 +2024-12-22,GA,0 +2024-12-23,GA,0 +2024-12-24,GA,0 +2024-12-25,GA,0 +2024-12-26,GA,0 +2024-12-27,GA,0 +2024-12-28,GA,0 +2024-12-29,GA,0 +2024-12-30,GA,0 +2024-12-31,GA,0 diff --git a/examples/emcee_vs_r_inference/model_input/initial_conditions.csv b/examples/emcee_vs_r_inference/model_input/initial_conditions.csv new file mode 100644 index 000000000..cb26fe982 --- /dev/null +++ b/examples/emcee_vs_r_inference/model_input/initial_conditions.csv @@ -0,0 +1,6 @@ +"subpop","mc_name","amount" +"NC","S","10264876" +"SC","S","5014806" +"SC","E","5000" +"SC","I","1000" +"GA","S","10403847" \ No newline at end of file diff --git a/examples/emcee_vs_r_inference/model_input/mobility.csv b/examples/emcee_vs_r_inference/model_input/mobility.csv new file mode 100644 index 000000000..caf0a90bd --- /dev/null +++ b/examples/emcee_vs_r_inference/model_input/mobility.csv @@ -0,0 +1,7 @@ +ori,dest,amount +"NC","SC","83998" +"SC","NC","134650" +"NC","GA","12498" +"GA","NC","11004" +"SC","GA","49722" +"GA","SC","41754" \ No newline at end of file diff --git a/examples/emcee_vs_r_inference/r_inference.yml b/examples/emcee_vs_r_inference/r_inference.yml new file mode 100644 index 000000000..85b5c67e4 --- /dev/null +++ b/examples/emcee_vs_r_inference/r_inference.yml @@ -0,0 +1,15 @@ +inference: + iterations_per_slot: 250 + do_inference: true + gt_data_path: model_input/ground_truth_hospitalizations.csv + statistics: + incidH: + name: incidH + sim_var: hospitalizations + data_var: incidH + add_one: True + remove_na: False + aggregator: sum + period: 1 days + likelihood: + dist: pois diff --git a/examples/emcee_vs_r_inference/sim_base.yml b/examples/emcee_vs_r_inference/sim_base.yml new file mode 100644 index 000000000..9232ef863 --- /dev/null +++ b/examples/emcee_vs_r_inference/sim_base.yml @@ -0,0 +1,147 @@ +name: three_state +start_date: 
2024-01-01 +end_date: 2024-12-31 +nslots: 8 +jobs: 8 + +subpop_setup: + geodata: model_input/geodata.csv + mobility: model_input/mobility.csv + +initial_conditions: + method: SetInitialConditions + initial_conditions_file: model_input/initial_conditions.csv + allow_missing_subpops: TRUE + allow_missing_compartments: TRUE + +compartments: + infection_stage: ["S", "E", "I", "R"] + +seir: + integration: + method: rk4 + dt: 0.25 + parameters: + sigma: + value: 1 / 4 + gamma: + value: 1 / 5 + Ro: + value: 1 + transitions: + - source: ["S"] + destination: ["E"] + rate: ["Ro * gamma"] + proportional_to: [["S"], ["I"]] + proportion_exponent: ["1", "1"] + - source: ["E"] + destination: ["I"] + rate: ["sigma"] + proportional_to: ["E"] + proportion_exponent: ["1"] + - source: ["I"] + destination: ["R"] + rate: ["gamma"] + proportional_to: ["I"] + proportion_exponent: ["1"] + +seir_modifiers: + scenarios: + - state_varied_Ro + modifiers: + nc_varied_Ro: + method: SinglePeriodModifier + parameter: Ro + subpop: NC + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 2.4 + sc_varied_Ro: + method: SinglePeriodModifier + parameter: Ro + subpop: SC + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 2.5 + ga_varied_Ro: + method: SinglePeriodModifier + parameter: Ro + subpop: GA + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 2.6 + state_varied_Ro: + method: StackedModifier + modifiers: + - nc_varied_Ro + - sc_varied_Ro + - ga_varied_Ro + + +outcomes: + method: delayframe + outcomes: + hospitalizations: + source: + incidence: + infection_stage: "I" + probability: + value: 1 + delay: + value: 1 + duration: + value: 10 + +outcome_modifiers: + scenarios: + - state_varied_incidH + modifiers: + nc_hosp_probability: + method: SinglePeriodModifier + parameter: incidH::probability + subpop: NC + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 0.05 + sc_hosp_probability: + method: SinglePeriodModifier + 
parameter: incidH::probability + subpop: SC + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 0.09 + ga_hosp_probability: + method: SinglePeriodModifier + parameter: incidH::probability + subpop: NC + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 0.07 + nc_sc_hosp_delay: + method: SinglePeriodModifier + parameter: incidH::delay + subpop: + - NC + - SC + subpop_groups: + - + - NC + - SC + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 7 + ga_hosp_delay: + method: SinglePeriodModifier + parameter: incidH::delay + subpop: GA + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 9 + state_varied_incidH: + method: StackedModifier + modifiers: + - nc_hosp_probability + - sc_hosp_probability + - ga_hosp_probability + - nc_sc_hosp_delay + - ga_hosp_delay diff --git a/examples/emcee_vs_r_inference/three_state_emcee_inference.yml b/examples/emcee_vs_r_inference/three_state_emcee_inference.yml new file mode 100644 index 000000000..5c8508348 --- /dev/null +++ b/examples/emcee_vs_r_inference/three_state_emcee_inference.yml @@ -0,0 +1,189 @@ +write_parquet: true +write_csv: false +jobs: 14 +first_sim_index: 1 +config_src: [inference_base.yml, inference_outcome_modifiers.yml, inference_seir_modifiers.yml, + emcee_inference.yml] +name: three_state +start_date: 2024-01-01 +end_date: 2024-12-31 +subpop_setup: + geodata: model_input/geodata.csv + mobility: model_input/mobility.csv +initial_conditions: + method: SetInitialConditions + initial_conditions_file: model_input/initial_conditions.csv + allow_missing_subpops: true + allow_missing_compartments: true +compartments: + infection_stage: [S, E, I, R] +seir: + integration: + method: rk4 + dt: 0.25 + parameters: + sigma: + value: 1 + gamma: + value: 1 + Ro: + value: 1 + transitions: [{source: [S], destination: [E], rate: ["Ro * gamma"], proportional_to: [ + [S], [I]], proportion_exponent: ['1', '1']}, {source: [E], destination: [ + I], rate: [sigma], 
proportional_to: [E], proportion_exponent: ['1']}, + {source: [I], destination: [R], rate: [gamma], proportional_to: [I], proportion_exponent: [ + '1']}] +outcomes: + method: delayframe + outcomes: + hospitalizations: + source: + incidence: + infection_stage: I + probability: + value: 1 + delay: + value: 1 + duration: + value: 1 +outcome_modifiers: + scenarios: [hosp_stacked] + modifiers: + hosp_probability: + method: SinglePeriodModifier + parameter: incidH::probability + subpop: all + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 0.1 + sd: 0.05 + a: 0 + b: 1 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + hosp_delay: + method: SinglePeriodModifier + parameter: incidH::delay + subpop: all + subpop_groups: [[NC, SC], [GA]] + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 7.0 + sd: 2.0 + a: 0 + b: 21.0 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + hosp_duration: + method: SinglePeriodModifier + parameter: incidH::duration + subpop: all + subpop_groups: [[NC, SC, GA]] + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 14.0 + sd: 7.0 + a: 3.0 + b: 21.0 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + hosp_stacked: + method: StackedModifier + subpop: all + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + modifiers: [hosp_probability, hosp_delay, hosp_duration] +seir_modifiers: + scenarios: [seir_stacked] + modifiers: + sigma_fit: + method: SinglePeriodModifier + parameter: sigma + subpop: all + subpop_groups: [NC, SC, GA] + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 0.5 + sd: 0.2 + a: 0 + b: 3 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + gamma_fit: + method: SinglePeriodModifier + 
parameter: gamma + subpop: all + subpop_groups: [NC, SC, GA] + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 0.5 + sd: 0.2 + a: 0 + b: 3 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + Ro_fit: + method: SinglePeriodModifier + parameter: Ro + subpop: all + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 2.5 + sd: 0.1 + a: 0 + b: 5 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + seir_stacked: + method: StackedModifier + modifiers: [sigma_fit, gamma_fit, Ro_fit] +inference: + iterations_per_slot: 250 + method: emcee + do_inference: true + gt_data_path: model_input/ground_truth_hospitalizations.csv + statistics: + incidH: + name: incidH + sim_var: hospitalizations + data_var: incidH + zero_to_one: true + likelihood: + dist: pois + diff --git a/examples/emcee_vs_r_inference/three_state_r_inference.yml b/examples/emcee_vs_r_inference/three_state_r_inference.yml new file mode 100644 index 000000000..dc8fe58f8 --- /dev/null +++ b/examples/emcee_vs_r_inference/three_state_r_inference.yml @@ -0,0 +1,191 @@ +write_csv: false +first_sim_index: 1 +jobs: 14 +write_parquet: true +config_src: [inference_base.yml, inference_outcome_modifiers.yml, inference_seir_modifiers.yml, + r_inference.yml] +name: three_state +start_date: 2024-01-01 +end_date: 2024-12-31 +subpop_setup: + geodata: model_input/geodata.csv + mobility: model_input/mobility.csv +initial_conditions: + method: SetInitialConditions + initial_conditions_file: model_input/initial_conditions.csv + allow_missing_subpops: true + allow_missing_compartments: true +compartments: + infection_stage: [S, E, I, R] +seir: + integration: + method: rk4 + dt: 0.25 + parameters: + sigma: + value: 1 + gamma: + value: 1 + Ro: + value: 1 + transitions: [{source: [S], destination: [E], rate: ["Ro * gamma"], proportional_to: [ + [S], [I]], 
proportion_exponent: ['1', '1']}, {source: [E], destination: [ + I], rate: [sigma], proportional_to: [E], proportion_exponent: ['1']}, + {source: [I], destination: [R], rate: [gamma], proportional_to: [I], proportion_exponent: [ + '1']}] +outcomes: + method: delayframe + outcomes: + hospitalizations: + source: + incidence: + infection_stage: I + probability: + value: 1 + delay: + value: 1 + duration: + value: 1 +outcome_modifiers: + scenarios: [hosp_stacked] + modifiers: + hosp_probability: + method: SinglePeriodModifier + parameter: incidH::probability + subpop: all + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 0.1 + sd: 0.05 + a: 0 + b: 1 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + hosp_delay: + method: SinglePeriodModifier + parameter: incidH::delay + subpop: all + subpop_groups: [[NC, SC], [GA]] + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 7.0 + sd: 2.0 + a: 0 + b: 21.0 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + hosp_duration: + method: SinglePeriodModifier + parameter: incidH::duration + subpop: all + subpop_groups: [[NC, SC, GA]] + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 14.0 + sd: 7.0 + a: 3.0 + b: 21.0 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + hosp_stacked: + method: StackedModifier + subpop: all + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + modifiers: [hosp_probability, hosp_delay, hosp_duration] +seir_modifiers: + scenarios: [seir_stacked] + modifiers: + sigma_fit: + method: SinglePeriodModifier + parameter: sigma + subpop: all + subpop_groups: [NC, SC, GA] + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 0.5 + sd: 0.2 + a: 0 + b: 3 + perturbation: + distribution: truncnorm + 
mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + gamma_fit: + method: SinglePeriodModifier + parameter: gamma + subpop: all + subpop_groups: [NC, SC, GA] + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 0.5 + sd: 0.2 + a: 0 + b: 3 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + Ro_fit: + method: SinglePeriodModifier + parameter: Ro + subpop: all + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: + distribution: truncnorm + mean: 2.5 + sd: 0.1 + a: 0 + b: 5 + perturbation: + distribution: truncnorm + mean: 0 + sd: 0.025 + a: -0.1 + b: 0.1 + seir_stacked: + method: StackedModifier + modifiers: [sigma_fit, gamma_fit, Ro_fit] +inference: + iterations_per_slot: 250 + do_inference: true + gt_data_path: model_input/ground_truth_hospitalizations.csv + statistics: + incidH: + name: incidH + sim_var: hospitalizations + data_var: incidH + add_one: true + remove_na: false + aggregator: sum + period: "1 days" + likelihood: + dist: pois + diff --git a/examples/emcee_vs_r_inference/three_state_simulate.yml b/examples/emcee_vs_r_inference/three_state_simulate.yml new file mode 100644 index 000000000..c36e44431 --- /dev/null +++ b/examples/emcee_vs_r_inference/three_state_simulate.yml @@ -0,0 +1,120 @@ +jobs: 14 +in_run_id: sim +write_csv: false +first_sim_index: 1 +write_parquet: true +nslots: 1 +seir: + integration: + method: stochastic + dt: 0.25 + parameters: + sigma: + value: "1 / 4" + gamma: + value: "1 / 5" + Ro: + value: 1 + transitions: [{source: [S], destination: [E], rate: ["Ro * gamma"], proportional_to: [ + [S], [I]], proportion_exponent: ['1', '1']}, {source: [E], destination: [ + I], rate: [sigma], proportional_to: [E], proportion_exponent: ['1']}, + {source: [I], destination: [R], rate: [gamma], proportional_to: [I], proportion_exponent: [ + '1']}] +config_src: [sim_base.yml] +name: three_state +start_date: 2024-01-01 +end_date: 2024-12-31 +subpop_setup: + 
geodata: model_input/geodata.csv + mobility: model_input/mobility.csv +initial_conditions: + method: SetInitialConditions + initial_conditions_file: model_input/initial_conditions.csv + allow_missing_subpops: true + allow_missing_compartments: true +compartments: + infection_stage: [S, E, I, R] +seir_modifiers: + scenarios: [state_varied_Ro] + modifiers: + nc_varied_Ro: + method: SinglePeriodModifier + parameter: Ro + subpop: NC + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 2.4 + sc_varied_Ro: + method: SinglePeriodModifier + parameter: Ro + subpop: SC + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 2.5 + ga_varied_Ro: + method: SinglePeriodModifier + parameter: Ro + subpop: GA + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 2.6 + state_varied_Ro: + method: StackedModifier + modifiers: [nc_varied_Ro, sc_varied_Ro, ga_varied_Ro] +outcomes: + method: delayframe + outcomes: + hospitalizations: + source: + incidence: + infection_stage: I + probability: + value: 1 + delay: + value: 1 + duration: + value: 10 +outcome_modifiers: + scenarios: [state_varied_incidH] + modifiers: + nc_hosp_probability: + method: SinglePeriodModifier + parameter: incidH::probability + subpop: NC + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 0.05 + sc_hosp_probability: + method: SinglePeriodModifier + parameter: incidH::probability + subpop: SC + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 0.09 + ga_hosp_probability: + method: SinglePeriodModifier + parameter: incidH::probability + subpop: NC + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 0.07 + nc_sc_hosp_delay: + method: SinglePeriodModifier + parameter: incidH::delay + subpop: [NC, SC] + subpop_groups: [[NC, SC]] + period_start_date: 2024-01-01 + period_end_date: 2024-12-31 + value: 7 + ga_hosp_delay: + method: SinglePeriodModifier + parameter: incidH::delay + subpop: GA + period_start_date: 
2024-01-01 + period_end_date: 2024-12-31 + value: 9 + state_varied_incidH: + method: StackedModifier + modifiers: [nc_hosp_probability, sc_hosp_probability, ga_hosp_probability, + nc_sc_hosp_delay, ga_hosp_delay] + diff --git a/examples/simple_usa_statelevel/simple_usa_statelevel.yml b/examples/simple_usa_statelevel/simple_usa_statelevel.yml index 22255b38b..1a6839399 100644 --- a/examples/simple_usa_statelevel/simple_usa_statelevel.yml +++ b/examples/simple_usa_statelevel/simple_usa_statelevel.yml @@ -56,11 +56,9 @@ seir_modifiers: period_end_date: 2024-02-01 subpop: "all" value: - distribution: truncnorm - mean: 1.3 - sd: 1 - a: 0.2 - b: 3 + distribution: gamma + shape: 169 / 100 + scale: 10 / 13 perturbation: distribution: truncnorm mean: 0 @@ -74,11 +72,9 @@ seir_modifiers: period_end_date: 2024-04-01 subpop: "all" value: - distribution: truncnorm - mean: 1.3 - sd: 3 - a: 0.1 - b: 10 + distribution: gamma + shape: 169 / 900 + scale: 90 / 13 perturbation: distribution: truncnorm mean: 0 diff --git a/flepimop/R_packages/inference/inst/scripts/flepimop-inference-main.R b/flepimop/R_packages/inference/inst/scripts/flepimop-inference-main.R index 5e5d7f4ed..2d20d43f6 100755 --- a/flepimop/R_packages/inference/inst/scripts/flepimop-inference-main.R +++ b/flepimop/R_packages/inference/inst/scripts/flepimop-inference-main.R @@ -133,8 +133,7 @@ foreach(seir_modifiers_scenario = seir_modifiers_scenarios) %:% "`flepimop-inference-slot` not found in PATH, unable to run inference slot" ) } - command <- c( - inference_slot_cmd, + args <- c( "-c", opt$config, "-u", opt$run_id, "-s", opt$seir_modifiers_scenarios, @@ -155,12 +154,18 @@ foreach(seir_modifiers_scenario = seir_modifiers_scenarios) %:% "-H", opt$save_hosp, "-M", opt$memory_profiling, "-P", opt$memory_profiling_iters, - "-g", opt$subpop_len, - sep = " " + "-g", opt$subpop_len + ) + writeLines( + paste("Running inference slot with args:", paste(args, collapse = " ")), + con = gsub("inference_slot", "cmd", log_file, 
fixed = TRUE) ) err <- tryCatch({ system2( - command = opt$rpath, args = command, stdout = log_file, stderr = log_file + command = inference_slot_cmd, + args = args, + stdout = log_file, + stderr = log_file ) }, error = function(e) { message <- paste("Error in slot", flepi_slot, ":", e$message) diff --git a/flepimop/gempyor_pkg/pyproject.toml b/flepimop/gempyor_pkg/pyproject.toml index 6936d3edc..66bce89ce 100644 --- a/flepimop/gempyor_pkg/pyproject.toml +++ b/flepimop/gempyor_pkg/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "pandas", "pyarrow", "pydantic>=2.10.0", + "pyyaml", "scipy", "seaborn", "sympy", diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py index f5b6aa0cb..259554af2 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py @@ -3,9 +3,9 @@ import numpy as np import pandas as pd -from . import helpers -from .base import NPIBase from ..distributions import distribution_from_confuse_config +from .base import NPIBase +from .helpers import SpatialGroups class MultiPeriodModifier(NPIBase): @@ -167,19 +167,24 @@ def __createFromConfig(self, npi_config): else: start_dates = [self.start_date] end_dates = [self.end_date] - this_spatial_group = helpers.get_spatial_groups( - grp_config, affected_subpops_grp + this_spatial_group = SpatialGroups.from_subpopulations( + list(affected_subpops_grp), + ( + grp_config["subpop_groups"].get() + if grp_config["subpop_groups"].exists() + else None + ), ) self.spatial_groups.append(this_spatial_group) # print(self.name, this_spatial_groups) # unfortunately, we cannot use .loc here, because it is not possible to assign a list of list # to a subset of a dataframe... so we iterate. 
- for subpop in this_spatial_group["ungrouped"]: + for subpop in this_spatial_group.ungrouped: self.parameters.at[subpop, "start_date"] = start_dates self.parameters.at[subpop, "end_date"] = end_dates self.parameters.at[subpop, "value"] = dist() - for group in this_spatial_group["grouped"]: + for group in this_spatial_group.grouped: drawn_value = dist() for subpop in group: self.parameters.at[subpop, "start_date"] = start_dates @@ -227,12 +232,17 @@ def __createFromDf(self, loaded_df, npi_config): else: start_dates = [self.start_date] end_dates = [self.end_date] - this_spatial_group = helpers.get_spatial_groups( - grp_config, affected_subpops_grp + this_spatial_group = SpatialGroups.from_subpopulations( + list(affected_subpops_grp), + ( + grp_config["subpop_groups"].get() + if grp_config["subpop_groups"].exists() + else None + ), ) self.spatial_groups.append(this_spatial_group) - for subpop in this_spatial_group["ungrouped"]: + for subpop in this_spatial_group.ungrouped: if not subpop in loaded_df.index: self.parameters.at[subpop, "start_date"] = start_dates self.parameters.at[subpop, "end_date"] = end_dates @@ -242,7 +252,7 @@ def __createFromDf(self, loaded_df, npi_config): self.parameters.at[subpop, "start_date"] = start_dates self.parameters.at[subpop, "end_date"] = end_dates self.parameters.at[subpop, "value"] = loaded_df.at[subpop, "value"] - for group in this_spatial_group["grouped"]: + for group in this_spatial_group.grouped: if ",".join(group) in loaded_df.index: # ordered, so it's ok for subpop in group: self.parameters.at[subpop, "start_date"] = start_dates @@ -306,7 +316,7 @@ def getReductionToWrite(self): for this_spatial_groups in self.spatial_groups: # spatially ungrouped dataframe df_ungroup = self.parameters[ - self.parameters.index.isin(this_spatial_groups["ungrouped"]) + self.parameters.index.isin(this_spatial_groups.ungrouped) ].copy() df_ungroup.index.name = "subpop" df_ungroup["start_date"] = df_ungroup["start_date"].apply( @@ -318,7 +328,7 @@ 
def getReductionToWrite(self): df_list.append(df_ungroup) # spatially grouped dataframe. They are nested within multitime reduce groups, # so we can set the same dates for allof them - for group in this_spatial_groups["grouped"]: + for group in this_spatial_groups.grouped: # we use the first subpop to represent the group df_group = self.parameters[self.parameters.index == group[0]].copy() diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py index ede2ce848..96f437831 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py @@ -3,9 +3,9 @@ import numpy as np import pandas as pd -from . import helpers -from .base import NPIBase from ..distributions import distribution_from_confuse_config +from .base import NPIBase +from .helpers import SpatialGroups class SinglePeriodModifier(NPIBase): @@ -149,17 +149,22 @@ def __createFromConfig(self, npi_config): else self.end_date ) self.parameters["parameter"] = self.param_name - self.spatial_groups = helpers.get_spatial_groups( - npi_config, list(self.affected_subpops) + self.spatial_groups = SpatialGroups.from_subpopulations( + list(self.affected_subpops), + ( + npi_config["subpop_groups"].get() + if npi_config["subpop_groups"].exists() + else None + ), ) - if self.spatial_groups["ungrouped"]: - self.parameters.loc[self.spatial_groups["ungrouped"], "value"] = ( - self.dist.sample(size=len(self.spatial_groups["ungrouped"])) + if self.spatial_groups.ungrouped: + self.parameters.loc[list(self.spatial_groups.ungrouped), "value"] = ( + self.dist.sample(size=len(self.spatial_groups.ungrouped)) ) - if self.spatial_groups["grouped"]: - for group in self.spatial_groups["grouped"]: + if self.spatial_groups.grouped: + for group in self.spatial_groups.grouped: drawn_value = np.repeat(self.dist(), len(group)) - self.parameters.loc[group, "value"] = drawn_value + 
self.parameters.loc[list(group), "value"] = drawn_value def __createFromDf(self, loaded_df, npi_config): loaded_df.index = loaded_df.subpop @@ -204,17 +209,22 @@ def __createFromDf(self, loaded_df, npi_config): # TODO: to be consistent with MTR, we want to also draw the values for the subpops # that are not in the loaded_df. - self.spatial_groups = helpers.get_spatial_groups( - npi_config, list(self.affected_subpops) + self.spatial_groups = SpatialGroups.from_subpopulations( + list(self.affected_subpops), + ( + npi_config["subpop_groups"].get() + if npi_config["subpop_groups"].exists() + else None + ), ) - if self.spatial_groups["ungrouped"]: - self.parameters.loc[self.spatial_groups["ungrouped"], "value"] = loaded_df.loc[ - self.spatial_groups["ungrouped"], "value" - ] - if self.spatial_groups["grouped"]: - for group in self.spatial_groups["grouped"]: - self.parameters.loc[group, "value"] = loaded_df.loc[ - ",".join(group), "value" + if self.spatial_groups.ungrouped: + self.parameters.loc[list(self.spatial_groups.ungrouped), "value"] = ( + loaded_df.loc[list(self.spatial_groups.ungrouped), "value"] + ) + if self.spatial_groups.grouped: + for group in self.spatial_groups.grouped: + self.parameters.loc[(list(group)), "value"] = loaded_df.loc[ + ",".join(list(group)), "value" ] def get_default(self, param): @@ -235,14 +245,14 @@ def getReduction(self, param): def getReductionToWrite(self): # spatially ungrouped dataframe df = self.parameters[ - self.parameters.index.isin(self.spatial_groups["ungrouped"]) + self.parameters.index.isin(self.spatial_groups.ungrouped) ].copy() df.index.name = "subpop" df["start_date"] = df["start_date"].astype("str") df["end_date"] = df["end_date"].astype("str") # spatially grouped dataframe - for group in self.spatial_groups["grouped"]: + for group in self.spatial_groups.grouped: # we use the first subpop to represent the group df_group = self.parameters[self.parameters.index == group[0]].copy() diff --git 
a/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py b/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py index aaeea15d5..24eb36b03 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py @@ -1,12 +1,177 @@ -import pandas as pd +"""Helpers for interacting with and using modifiers.""" + +__all__ = ("SpatialGroups", "reduce_parameter") + +from collections.abc import Iterator, Sequence +from dataclasses import dataclass, field +from typing import Literal + import numpy as np -import typing +import pandas as pd + +from ..utils import _flatten_list_of_lists, _make_list_of_list + + +@dataclass(frozen=True) +class SpatialGroups: + """ + Modifier spatial groups. + + Attributes: + grouped: List of lists of subpopulations that share the same modifier value. + ungrouped: List of subpopulations that have individual modifier values. + + Examples: + >>> from gempyor.NPI.helpers import SpatialGroups + >>> sp = SpatialGroups(grouped=(("A", "B"), ("C",)), ungrouped=("D", "E")) + >>> sp.grouped + (('A', 'B'), ('C',)) + >>> sp.ungrouped + ('D', 'E') + >>> for kind, group in sp: + ... print(f"{kind}: {group}") + ungrouped: ('D',) + ungrouped: ('E',) + grouped: ('A', 'B') + grouped: ('C',) + + """ + + grouped: tuple[tuple[str]] = field(default_factory=tuple) + ungrouped: tuple[str] = field(default_factory=tuple) + + def __iter__(self) -> Iterator[tuple[Literal["grouped", "ungrouped"], tuple[str]]]: + yield from zip(len(self.ungrouped) * ("ungrouped",), ((g,) for g in self.ungrouped)) + yield from zip(len(self.grouped) * ("grouped",), self.grouped) + + @classmethod + def from_dict( + cls, x: dict[str, Sequence[str] | Sequence[Sequence[str]]] + ) -> "SpatialGroups": + # pylint: disable=line-too-long + """ + Create a `SpatialGroups` instance from a dictionary. + + Args: + x: Dictionary with 'grouped' and 'ungrouped' keys. If a key is missing, it + defaults to an empty tuple. + + Returns: + A `SpatialGroups` instance. 
+ + Raises: + TypeError: If the value for 'grouped' is not a sequence of sequences of + strings or strings + TypeError: If the value for 'ungrouped' is not a sequence of strings. + + Examples: + >>> from gempyor.NPI.helpers import SpatialGroups + >>> SpatialGroups.from_dict( + ... {"grouped": [["A", "B"], ["C"]], "ungrouped": ["D", "E"]} + ... ) + SpatialGroups(grouped=(('A', 'B'), ('C',)), ungrouped=('D', 'E')) + >>> SpatialGroups.from_dict({"ungrouped": ["D", "E"]}) + SpatialGroups(grouped=(), ungrouped=('D', 'E')) + >>> SpatialGroups.from_dict({"grouped": [["A", "B"], ["C"]]}) + SpatialGroups(grouped=(('A', 'B'), ('C',)), ungrouped=()) + >>> SpatialGroups.from_dict({}) + SpatialGroups(grouped=(), ungrouped=()) + >>> SpatialGroups.from_dict({"grouped": "AB"}) + Traceback (most recent call last): + ... + TypeError: The 'grouped' key is type , expected a sequence of sequences of strings or strings. The value was grouped='AB'. + >>> SpatialGroups.from_dict({"ungrouped": "DE"}) + Traceback (most recent call last): + ... + TypeError: The 'ungrouped' key is type , expected a sequence of strings. The value was ungrouped='DE'. + + """ + # pylint: enable=line-too-long + # Preliminary type checks + grouped = x.get("grouped", ()) + if not ( + (isinstance(grouped, Sequence) and not isinstance(grouped, str)) + and all(isinstance(g, (Sequence, str)) for g in grouped) + and all( + isinstance(s, str) + for g in grouped + for s in (g if isinstance(g, Sequence) else [g]) + ) + ): + msg = ( + f"The 'grouped' key is type {type(grouped)}, expected a sequence " + f"of sequences of strings or strings. The value was {grouped=}." + ) + raise TypeError(msg) + ungrouped = x.get("ungrouped", ()) + if not ( + (isinstance(ungrouped, Sequence) and not isinstance(ungrouped, str)) + and all(isinstance(s, str) for s in ungrouped) + ): + msg = ( + f"The 'ungrouped' key is type {type(ungrouped)}, expected a sequence " + f"of strings. The value was {ungrouped=}." 
+ ) + raise TypeError(msg) + # Convert to tuples for immutability + return cls( + grouped=tuple( + tuple(sorted(g)) if isinstance(g, Sequence) else (g,) for g in grouped + ), + ungrouped=tuple(sorted(ungrouped)), + ) + + @classmethod + def from_subpopulations( + cls, + subpopulations: list[str], + subpopulation_groups: list[list[str]] | list[str] | str | None, + ) -> "SpatialGroups": + """ + Get the spatial groupings from a modifier group config. + + Args: + grp_config: Configuration view containing 'subpop_groups' key. + affected_subpops: List of subpopulations affected by the modifier. + + Returns: + A `SpatialGroups` instance constructed from modifier subpopulations. + + Examples: + >>> from gempyor.NPI.helpers import SpatialGroups + >>> SpatialGroups.from_subpopulations(["A", "B", "C"], None) + SpatialGroups(grouped=(), ungrouped=('A', 'B', 'C')) + >>> SpatialGroups.from_subpopulations(["A", "B", "C"], [["A", "B"], ["C"]]) + SpatialGroups(grouped=(('A', 'B'), ('C',)), ungrouped=()) + >>> SpatialGroups.from_subpopulations(["A", "B", "C"], "all") + SpatialGroups(grouped=(('A', 'B', 'C'),), ungrouped=()) + >>> SpatialGroups.from_subpopulations( + ... ["A", "B", "C", "D", "E", "F"], + ... [["A", "B"], [], ["E", "F"]], + ... 
) + SpatialGroups(grouped=(('A', 'B'), ('E', 'F')), ungrouped=('C', 'D')) + + """ + spatial_groups = {} + if subpopulation_groups is None: + spatial_groups["ungrouped"] = subpopulations + elif subpopulation_groups == "all": + spatial_groups["grouped"] = [subpopulations] + else: + spatial_groups["grouped"] = [ + subgrp + for grp in _make_list_of_list(subpopulation_groups) + if (subgrp := list(set(grp).intersection(subpopulations))) + ] + spatial_groups["ungrouped"] = list( + set(subpopulations) - set(_flatten_list_of_lists(subpopulation_groups)) + ) + return SpatialGroups.from_dict(spatial_groups) -# Helper function def reduce_parameter( parameter: np.ndarray, - modification: typing.Union[pd.DataFrame, float], + modification: pd.DataFrame | float, method: str = "product", ) -> np.ndarray: if isinstance(modification, pd.DataFrame): @@ -21,79 +186,3 @@ def reduce_parameter( return parameter * modification else: raise ValueError(f"Unknown method to do NPI reduction, got {method}") - - -def get_spatial_groups(grp_config, affected_subpops: list) -> dict: - """ - Spatial groups are defined in the config file as a list (of lists). - They have the same value. 
- grouped is a list of lists of subpops - ungrouped is a list of subpops - the list are ordered, and this is important so we can get back and forth - from the written to disk part that is comma separated - """ - - spatial_groups = {"grouped": [], "ungrouped": []} - - if not grp_config["subpop_groups"].exists(): - spatial_groups["ungrouped"] = affected_subpops - else: - if grp_config["subpop_groups"].get() == "all": - spatial_groups["grouped"] = [affected_subpops] - else: - spatial_groups["grouped"] = grp_config["subpop_groups"].get() - spatial_groups["ungrouped"] = list( - set(affected_subpops) - - set(flatten_list_of_lists(spatial_groups["grouped"])) - ) - - # flatten the list of lists of grouped subpops, so we can do some checks - flat_grouped_list = flatten_list_of_lists(spatial_groups["grouped"]) - # check that all subpops are either grouped or ungrouped - - # if set(flat_grouped_list + spatial_groups["ungrouped"]) != set(affected_subpops): - # print("set of grouped and ungrouped subpops", set(flat_grouped_list + spatial_groups["ungrouped"])) - # print("set of affected subpops ", set(affected_subpops)) - # raise ValueError(f"The two above sets are differs for for intervention with config \n {grp_config}") - # if len(set(flat_grouped_list + spatial_groups["ungrouped"])) != len( - # flat_grouped_list + spatial_groups["ungrouped"] - # ): - # raise ValueError( - # f"subpop_groups error. 
For intervention with config \n {grp_config} \n duplicate entries in the set of grouped and ungrouped subpops" - # f" {flat_grouped_list + spatial_groups['ungrouped']} vs {set(flat_grouped_list + spatial_groups['ungrouped'])}" - # ) - - spatial_groups["grouped"] = make_list_of_list(spatial_groups["grouped"]) - - # sort the lists - spatial_groups["grouped"] = [ - sorted(list(set(x).intersection(affected_subpops))) - for x in spatial_groups["grouped"] - ] - spatial_groups["ungrouped"] = sorted( - list(set(spatial_groups["ungrouped"]).intersection(affected_subpops)) - ) - - # remove empty sublist in grp{'grouped': [[], ['01000_14to15', '01000_17to18', '01000_22to23'], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []], - spatial_groups["grouped"] = [x for x in spatial_groups["grouped"] if x] - - return spatial_groups - - -def flatten_list_of_lists(list_of_lists): - """flatten a list of lists into a single list, or return the original list if it is not a list of lists""" - if not list_of_lists: - return list_of_lists # empty list - elif not isinstance(list_of_lists[0], list): - return list_of_lists - return [item for sublist in list_of_lists for item in sublist] - - -def make_list_of_list(this_list): - """if the list contains its' values, nest it into another list""" - if not this_list: - return this_list # empty list - elif isinstance(this_list[0], list): - return this_list - else: - return [this_list] diff --git a/flepimop/gempyor_pkg/src/gempyor/calibrate.py b/flepimop/gempyor_pkg/src/gempyor/calibrate.py index 87ad322ea..b7040f721 100644 --- a/flepimop/gempyor_pkg/src/gempyor/calibrate.py +++ b/flepimop/gempyor_pkg/src/gempyor/calibrate.py @@ -192,6 +192,10 @@ def calibrate( # and then acceptances are not guaranted, see issue #316. 
This solves this issue and greates a new chain with llik evaluation p0 = backend.get_last_sample().coords else: + print( + f"nwalkers: {nwalkers}, inference parameters dim: " + f"{gempyor_inference.inferpar.get_dim()}." + ) backend.reset(nwalkers, gempyor_inference.inferpar.get_dim()) p0 = gempyor_inference.inferpar.draw_initial(n_draw=nwalkers) for i in range(nwalkers): @@ -316,8 +320,8 @@ def calibrate( ) print("EMCEE Run done, doing sampling") - shutil.rmtree("model_output/", ignore_errors=True) - shutil.rmtree(os.path.join(project_path, "model_output/"), ignore_errors=True) + # shutil.rmtree("model_output/", ignore_errors=True) + # shutil.rmtree(os.path.join(project_path, "model_output/"), ignore_errors=True) max_indices = np.argsort(sampler.get_log_prob()[-1, :])[-nsamples:] samples = sampler.get_chain()[ diff --git a/flepimop/gempyor_pkg/src/gempyor/distributions.py b/flepimop/gempyor_pkg/src/gempyor/distributions.py index fa983dcb6..6c9d09868 100644 --- a/flepimop/gempyor_pkg/src/gempyor/distributions.py +++ b/flepimop/gempyor_pkg/src/gempyor/distributions.py @@ -13,18 +13,19 @@ "TruncatedNormalDistribution", "UniformDistribution", "WeibullDistribution", + "distribution_from_confuse_config", ) -import confuse from abc import ABC, abstractmethod -from math import isclose +from math import inf, isclose from typing import Annotated, Literal +import confuse import numpy as np -from numpy.random import Generator import numpy.typing as npt -from pydantic import BaseModel, PrivateAttr, Field, TypeAdapter, model_validator +from numpy.random import Generator +from pydantic import BaseModel, Field, PrivateAttr, TypeAdapter, model_validator from scipy.stats import truncnorm from ._pydantic_ext import EvaledFloat, EvaledInt @@ -37,6 +38,12 @@ class DistributionABC(ABC, BaseModel): allow_edge_cases: bool = False _rng: Generator = PrivateAttr(default_factory=np.random.default_rng) + _lower_bound: float | int | None = PrivateAttr(default=None) + _upper_bound: float | int | 
None = PrivateAttr(default=None) + + def __call__(self) -> float | int: + """A shortcut for `self.sample(size=1)`.""" + return self.sample(size=1).item() def sample( self, size: int | tuple[int, ...] = 1, rng: Generator | None = None @@ -55,10 +62,6 @@ def sample( rng = rng if rng is not None else self._rng return self._sample_from_generator(size=size, rng=rng) - def __call__(self) -> float | int: - """A shortcut for `self.sample(size=1)`.""" - return self.sample(size=1).item() - @abstractmethod def _sample_from_generator( self, size: int | tuple[int, ...], rng: Generator @@ -69,9 +72,63 @@ def _sample_from_generator( Args: size: The desired output size of samples to be drawn. rng: A NumPy random number generator instance used for sampling. + + Returns: + A NumPy array of either floats/ints (depending on distribution) + drawn from the distribution with shape `size`. """ raise NotImplementedError + def _bound(self, kind: Literal["lower", "upper"]) -> float | int: + """ + Get the lower or upper bound of the distribution's support. + + Args: + kind: Either "lower" or "upper" to specify which bound to return. + + Returns: + The lower or upper bound of the distribution's support. + """ + if (bnd := getattr(self, f"_{kind}_bound", None)) is not None: + return bnd + msg = ( + f"{kind.title()} bound not defined for {self.distribution} " + f"distribution. Implementation must either define `_{kind}_bound` " + f"property or override `_{kind}` to handle bounds." + ) + raise NotImplementedError(msg) + + @property + def _lower(self) -> float | int: + """ + The lower bound of the distribution's support. + + Returns: + The lower bound of the distribution's support as a float or int. + """ + return self._bound("lower") + + @property + def _upper(self) -> float | int: + """ + The upper bound of the distribution's support. + + Returns: + The upper bound of the distribution's support as a float or int. 
+ """ + return self._bound("upper") + + @property + def support(self) -> tuple[float | int, float | int]: + """ + The theoretical support of the distribution as a (min, max) tuple. + + Returns: + A tuple representing the upper and lower bounds in the distribution's + support. + """ + return (self._lower, self._upper) + class FixedDistribution(DistributionABC): """ @@ -86,6 +143,8 @@ class FixedDistribution(DistributionABC): array([[1.23, 1.23, 1.23, 1.23, 1.23], [1.23, 1.23, 1.23, 1.23, 1.23], [1.23, 1.23, 1.23, 1.23, 1.23]]) + >>> dist.support + (1.23, 1.23) """ distribution: Literal["fixed"] = "fixed" @@ -97,6 +156,16 @@ def _sample_from_generator( """Sampling logic for fixed distributions.""" return np.full(size, self.value) + @property + def _lower(self) -> float: + """The lower bound of the fixed distribution's support.""" + return self.value + + @property + def _upper(self) -> float: + """The upper bound of the fixed distribution's support.""" + return self.value + class NormalDistribution(DistributionABC): """ @@ -115,12 +184,17 @@ class NormalDistribution(DistributionABC): array([[-2.37992848, 5.67703038, 6.53254122, -6.47965835, -3.55980778], [ 2.87528181, 0.87690833, 2.22439479, -1.53869767, 6.25729089], [ 5.80006371, 2.59713814, 7.37258543, 4.40379204, -1.56681608]]) + >>> dist.support + (-inf, inf) """ distribution: Literal["norm"] = "norm" mu: EvaledFloat sigma: EvaledFloat = Field(..., gt=0) + _lower_bound: float = PrivateAttr(default=-inf) + _upper_bound: float = PrivateAttr(default=inf) + def _sample_from_generator( self, size: int | tuple[int, ...], rng: Generator ) -> npt.NDArray[np.float64]: @@ -129,6 +203,7 @@ def _sample_from_generator( class UniformDistribution(DistributionABC): + # pylint: disable=line-too-long """ Represents a uniform distribution. 
@@ -145,10 +220,14 @@ class UniformDistribution(DistributionABC): array([[ 0.37775688, 1.21719584, 0.89473606, -0.3116453 , 1.4512447 ], [ 1.0222794 , 1.07212861, -0.24377273, 0.40077188, 0.24159605], [ 1.35352998, 0.78773024, 1.14552323, 0.3868284 , -0.04552256]]) + >>> dist.support + (-0.5, 1.5) >>> # With `low == high` and `allow_edge_cases=True`, all samples == `low`. >>> dist_edge = UniformDistribution(low=5.0, high=5.0, allow_edge_cases=True) >>> dist_edge.sample(size=5) array([5., 5., 5., 5., 5.]) + >>> dist_edge.support + (5.0, 5.0) >>> # Without `allow_edge_cases` set to True, it fails by default when `low == high`. >>> UniformDistribution(low=5.0, high=5.0) Traceback (most recent call last): @@ -156,6 +235,7 @@ class UniformDistribution(DistributionABC): pydantic_core._pydantic_core.ValidationError: 1 validation error for UniformDistribution Value error, Upper bound `high`, 5.0, must be > to lower bound `low`, 5.0. [type=value_error, ... """ + # pylint: enable=line-too-long distribution: Literal["uniform"] = "uniform" low: EvaledFloat @@ -167,6 +247,16 @@ def _sample_from_generator( """Sampling logic for uniform distributions.""" return rng.uniform(low=self.low, high=self.high, size=size) + @property + def _lower(self) -> float: + """The lower bound of the uniform distribution's support.""" + return self.low + + @property + def _upper(self) -> float: + """The upper bound of the uniform distribution's support.""" + return self.high + @model_validator(mode="after") def _validate_bounds(self) -> "UniformDistribution": """Validate bounds based on whether or not edge cases are allowed.""" @@ -174,13 +264,16 @@ def _validate_bounds(self) -> "UniformDistribution": not self.allow_edge_cases and isclose(self.high, self.low) ): op = ">=" if self.allow_edge_cases else ">" - raise ValueError( - f"Upper bound `high`, {self.high}, must be {op} to lower bound `low`, {self.low}." 
+ msg = ( + f"Upper bound `high`, {self.high}, must be " + f"{op} to lower bound `low`, {self.low}." ) + raise ValueError(msg) return self class LognormalDistribution(DistributionABC): + # pylint: disable=line-too-long """ Represents a Lognormal distribution. @@ -197,12 +290,18 @@ class LognormalDistribution(DistributionABC): array([[0.3534603 , 2.11795541, 2.56142749, 0.14212687, 0.27193845], [1.13637163, 0.72888261, 0.98333919, 0.42611589, 2.40944872], [2.17666075, 1.06825951, 3.08712799, 1.59601411, 0.42346159]]) + >>> dist.support + (0.0, inf) """ + # pylint: enable=line-too-long distribution: Literal["lognorm"] = "lognorm" meanlog: EvaledFloat sdlog: EvaledFloat = Field(..., gt=0) + _lower_bound: float = PrivateAttr(default=0.0) + _upper_bound: float = PrivateAttr(default=inf) + def _sample_from_generator( self, size: int | tuple[int, ...], rng: Generator ) -> npt.NDArray[np.float64]: @@ -228,11 +327,15 @@ class TruncatedNormalDistribution(DistributionABC): array([[1.07000038, 2.18016199, 1.66002835, 0.28689654, 3.04332767], [1.83818339, 1.9153892 , 0.37639339, 1.09435167, 0.92629918], [2.54134935, 1.52545861, 2.04022114, 1.07959285, 0.6142512 ]]) + >>> dist.support + (0.0, 10.0) >>> # With `a == b` and `allow_edge_cases=True`, all samples == `a`. >>> dist_edge = TruncatedNormalDistribution(mean=5.0, sd=2.0, a=7.0, b=7.0, allow_edge_cases=True) >>> dist_edge.sample(size=5) array([7., 7., 7., 7., 7.]) - >>> # Withoug `allow_edge_cases` set to True, it fails by default when `a == b`. + >>> dist_edge.support + (7.0, 7.0) + >>> # Without `allow_edge_cases` set to True, it fails by default when `a == b`. >>> TruncatedNormalDistribution(mean=5.0, sd=2.0, a=7.0, b=7.0) Traceback (most recent call last): ... 
@@ -267,6 +370,16 @@ def _sample_from_generator( random_state=rng, ) + @property + def _lower(self) -> float: + """The lower bound of the truncated normal distribution's support.""" + return self.a + + @property + def _upper(self) -> float: + """The upper bound of the truncated normal distribution's support.""" + return self.b + @model_validator(mode="after") def _validate_bounds(self) -> "TruncatedNormalDistribution": """Validate bounds based on whether or not edge cases are allowed..""" @@ -279,6 +392,7 @@ def _validate_bounds(self) -> "TruncatedNormalDistribution": class PoissonDistribution(DistributionABC): + # pylint: disable=line-too-long """ Represents a Poisson distribution. @@ -295,10 +409,14 @@ class PoissonDistribution(DistributionABC): array([[4, 5, 1, 7, 1], [4, 2, 2, 5, 4], [1, 6, 2, 5, 0]]) + >>> dist.support + (0, inf) >>> # With `lam=0` and `allow_edge_cases=True`, all samples will be 0. >>> dist_edge = PoissonDistribution(lam=0.0, allow_edge_cases=True) >>> dist_edge.sample(size=5) array([0, 0, 0, 0, 0]) + >>> dist_edge.support + (0, 0) >>> # Without `allow_edge_cases` explicitly set to True, it fails by default. >>> PoissonDistribution(lam=0.0) Traceback (most recent call last): @@ -306,16 +424,24 @@ class PoissonDistribution(DistributionABC): pydantic_core._pydantic_core.ValidationError: 1 validation error for PoissonDistribution Value error, Input for `lam` cannot be zero when `allow_edge_cases` is `False`. [type=value_error, ... 
""" + # pylint: enable=line-too-long distribution: Literal["poisson"] = "poisson" lam: EvaledFloat = Field(..., ge=0.0) + _lower_bound: int = PrivateAttr(default=0) + def _sample_from_generator( self, size: int | tuple[int, ...], rng: Generator ) -> npt.NDArray[np.int64]: """Sampling logic for Poisson distributions.""" return rng.poisson(lam=self.lam, size=size) + @property + def _upper(self) -> float | int: + """The upper bound of the Poisson distribution is inf.""" + return 0 if isclose(self.lam, 0.0) else inf + @model_validator(mode="after") def _validate_lambda(self) -> "PoissonDistribution": if not self.allow_edge_cases and isclose(self.lam, 0.0): @@ -326,6 +452,7 @@ def _validate_lambda(self) -> "PoissonDistribution": class BinomialDistribution(DistributionABC): + # pylint: disable=line-too-long """ Represents a binomial distribution. @@ -342,10 +469,14 @@ class BinomialDistribution(DistributionABC): array([[5, 7, 6, 3, 8], [6, 6, 3, 5, 4], [7, 6, 6, 5, 4]]) + >>> dist.support + (0, 10) >>> # It succeeds with `p=0` or `p=1` when `allow_edge_cases=True`. >>> dist_edge = BinomialDistribution(n=10, p=1.0, allow_edge_cases=True) >>> dist_edge.sample(size=5) array([10, 10, 10, 10, 10]) + >>> dist_edge.support + (10, 10) >>> # Without `allow_edge_cases` set to True, it fails by default when `p=0` or `p=1`. >>> BinomialDistribution(n=10, p=0.0) Traceback (most recent call last): @@ -354,6 +485,7 @@ class BinomialDistribution(DistributionABC): Value error, Input for `p` cannot be 0 or 1 when `allow_edge_cases` is `False`. 
[type=value_error, input_value={'n': 10, 'p': 0.0}, input_type=dict] For further information visit https://errors.pydantic.dev/2.11/v/value_error """ + # pylint: enable=line-too-long distribution: Literal["binomial"] = "binomial" n: EvaledInt = Field(..., ge=0) @@ -365,6 +497,16 @@ def _sample_from_generator( """Sampling logic for binomial distributions.""" return rng.binomial(n=self.n, p=self.p, size=size) + @property + def _lower(self) -> int: + """The lower bound of the binomial distribution is 0.""" + return self.n if isclose(self.p, 1.0) else 0 + + @property + def _upper(self) -> int: + """The upper bound of the binomial distribution is n.""" + return 0 if isclose(self.p, 0.0) else self.n + @model_validator(mode="after") def _validate_params(self) -> "BinomialDistribution": """Validate params based on whether or not edge cases are allowed.""" @@ -381,6 +523,7 @@ def _validate_params(self) -> "BinomialDistribution": class GammaDistribution(DistributionABC): + # pylint: disable=line-too-long """ Represents a gamma distribution. 
@@ -397,20 +540,32 @@ class GammaDistribution(DistributionABC): array([[4.25301838, 2.75582337, 2.46760563, 4.61888299, 2.63006031], [3.51900762, 3.28422915, 4.61612039, 2.15883076, 5.69337578], [1.75889742, 3.67897959, 3.38745296, 1.6334585 , 3.89261212]]) + >>> dist.support + (0.0, inf) """ + # pylint: enable=line-too-long distribution: Literal["gamma"] = "gamma" shape: EvaledFloat = Field(..., gt=0) scale: EvaledFloat = Field(..., gt=0) + _lower_bound: float = PrivateAttr(default=0.0) + _upper_bound: float = PrivateAttr(default=inf) + def _sample_from_generator( self, size: int | tuple[int, ...], rng: Generator ) -> npt.NDArray[np.float64]: """Sampling logic for Gamma distributions.""" return rng.gamma(shape=self.shape, scale=self.scale, size=size) + @property + def support(self) -> tuple[float, float]: + """The theoretical support of the gamma distribution is [0, inf).""" + return (0.0, inf) + class WeibullDistribution(DistributionABC): + # pylint: disable=line-too-long """ Represents a weibull distribution. 
@@ -427,12 +582,18 @@ class WeibullDistribution(DistributionABC): array([[7.02058405, 7.0786094 , 3.00403884, 1.87780582, 5.8054469 ], [5.73657673, 7.886256 , 1.81412232, 5.0918523 , 1.73016943], [5.17350562, 6.22761392, 3.4198509 , 5.43445321, 2.36440749]]) + >>> dist.support + (0.0, inf) """ + # pylint: enable=line-too-long distribution: Literal["weibull"] = "weibull" shape: EvaledFloat = Field(..., gt=0) scale: EvaledFloat = Field(..., gt=0) + _lower_bound: float = PrivateAttr(default=0.0) + _upper_bound: float = PrivateAttr(default=inf) + def _sample_from_generator( self, size: int | tuple[int, ...], rng: Generator ) -> npt.NDArray[np.float64]: @@ -458,12 +619,17 @@ class BetaDistribution(DistributionABC): array([[0.28406092, 0.39027204, 0.29864681, 0.41835336, 0.49963165], [0.30396328, 0.15089427, 0.32937986, 0.52373987, 0.16127411], [0.32746504, 0.48761242, 0.2162056 , 0.29178583, 0.22819733]]) + >>> dist.support + (0.0, 1.0) """ distribution: Literal["beta"] = "beta" alpha: EvaledFloat = Field(..., gt=0) beta: EvaledFloat = Field(..., gt=0) + _lower_bound: float = PrivateAttr(default=0.0) + _upper_bound: float = PrivateAttr(default=1.0) + def _sample_from_generator( self, size: int | tuple[int, ...], rng: Generator ) -> npt.NDArray[np.float64]: diff --git a/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py b/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py index 9bed34792..5e6910b3b 100644 --- a/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py +++ b/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py @@ -1,79 +1,149 @@ """ -Managing inference paramters in a vectorized way. +Managing inference parameters in a vectorized way. """ -import xarray as xr -import pandas as pd +__all__ = ("InferenceParameters",) + +from collections import Counter +from typing import Literal + import numpy as np -import confuse -from . 
import NPI -from .distributions import distribution_from_confuse_config +import numpy.typing as npt +import pandas as pd +from confuse import ConfigView + +from .distributions import DistributionABC, distribution_from_confuse_config +from .NPI.helpers import SpatialGroups -# TODO cast uper and lower bound as arrays class InferenceParameters: """ - A class to manage inference parameters, in a vectorized way - - Parameters: - global_config (confuse.ConfigView): The global configuration. - subpop_names (list): The subpopulation names, in the right order + A class to manage vectorized inference parameters. """ - def __init__(self, global_config, subpop_names): - self.ptypes = [] - self.pnames = [] - self.subpops = [] - self.pdists = [] - self.ubs = [] - self.lbs = [] + def __init__(self, global_config: ConfigView, subpop_names: list[str]) -> None: + """ + Initializes the `InferenceParameters` instance. + + This constructor sets up the inference parameters based on the provided global + configuration and subpopulation names by calling the `build_from_config` method. + + Args: + global_config: The global configuration represented as a + `confuse.ConfigView`. + subpop_names: A list of subpopulation names to be used in + the configuration. 
+ """ + self.ptypes: list[Literal["outcome_modifiers", "seir_modifiers"]] = [] + self.pnames: list[str] = [] + self.subpops: list[str] = [] + self.pdists: list[DistributionABC] = [] + self._lower_bounds: npt.NDArray[np.float64] = np.array([], dtype=np.float64) + self._upper_bounds: npt.NDArray[np.float64] = np.array([], dtype=np.float64) self.build_from_config(global_config, subpop_names) - def add_modifier(self, pname, ptype, parameter_config, subpops): + def __str__(self) -> str: + return f"InferenceParameters: with {len(self)} parameters: \n" + "\n".join( + f" {key}: {value} parameters" for key, value in Counter(self.ptypes).items() + ) + + def __len__(self) -> int: + return len(self.pnames) + + def print_summary(self) -> None: + """ + Prints a summary of the inference parameters. + + Produces a summary of the parameters, including their types, names, bounds and + relevant subpopulations. This is useful for debugging and understanding the + configuration. + """ + dim = len(self) + print(f"There are {dim} parameters in the configuration.") + for p_idx in range(dim): + lower, upper = self.pdists[p_idx].support + print( + f"{self.ptypes[p_idx]}::{self.pnames[p_idx]} in " + f"[{lower}, {upper}] >> affected subpop: {self.subpops[p_idx]}" + ) + + def build_from_config(self, global_config: ConfigView, subpop_names: list[str]) -> None: + """ + Constructs the inference parameters from the global configuration. + + Args: + global_config: The global configuration represented as a + `confuse.ConfigView`. + subpop_names: A list of subpopulation names to be used in + the configuration. 
+ """ + for config_part in ["seir_modifiers", "outcome_modifiers"]: + if global_config[config_part].exists(): + for npi in global_config[config_part]["modifiers"].get(): + if global_config[config_part]["modifiers"][npi][ + "perturbation" + ].exists(): + self.add_modifier( + pname=npi, + ptype=config_part, + parameter_config=global_config[config_part]["modifiers"][npi], + subpops=subpop_names, + ) + + def add_modifier( + self, + pname: str, + ptype: Literal["outcome_modifiers", "seir_modifiers"], + parameter_config: ConfigView, + subpops: list[str], + ) -> None: """ - Adds a modifier parameter to the parameters list. + Adds a modifier parameter to the inference parameters representation. Args: - pname (str): The parameter name. - ptype (str): The parameter type. - parameter_config (confuse.ConfigView): The configuration for the parameter. - subpops (list): List of subpopulations affected by the modifier. + pname: The name of the parameter. + ptype: The parameter type, must be one of "outcome_modifiers" or + "seir_modifiers". + parameter_config: The confuse representation of the parameter configuration. + subpops: A list of subpopulations affected by the modifier. 
""" # identify spatial group - affected_subpops = set(subpops) + affected_subpops = subpops if parameter_config["method"].get() == "SinglePeriodModifier": if ( parameter_config["subpop"].exists() and parameter_config["subpop"].get() != "all" ): - affected_subpops = {str(n.get()) for n in parameter_config["subpop"]} - spatial_groups = NPI.helpers.get_spatial_groups( - parameter_config, list(affected_subpops) + affected_subpops = [str(n.get()) for n in parameter_config["subpop"]] + spatial_groups = SpatialGroups.from_subpopulations( + list(affected_subpops), + ( + parameter_config["subpop_groups"].get() + if parameter_config["subpop_groups"].exists() + else None + ), ) - - # ungrouped subpop (all affected subpop by default) have one parameter per subpop - if spatial_groups["ungrouped"]: - for sp in spatial_groups["ungrouped"]: + # ungrouped subpop (all affected subpop by + # default) have one parameter per subpop + if spatial_groups.ungrouped: + for sp in spatial_groups.ungrouped: + dist = distribution_from_confuse_config(parameter_config["value"]) self.add_single_parameter( ptype=ptype, pname=pname, subpop=sp, - pdist=distribution_from_confuse_config(parameter_config["value"]), - lb=parameter_config["value"]["a"].get(float), - ub=parameter_config["value"]["b"].get(float), + pdist=dist, ) - # grouped subpop have one parameter per group - if spatial_groups["grouped"]: - for group in spatial_groups["grouped"]: + if spatial_groups.grouped: + for group in spatial_groups.grouped: + dist = distribution_from_confuse_config(parameter_config["value"]) self.add_single_parameter( ptype=ptype, pname=pname, subpop=",".join(group), - pdist=distribution_from_confuse_config(parameter_config["value"]), - lb=parameter_config["value"]["a"].get(float), - ub=parameter_config["value"]["b"].get(float), + pdist=dist, ) elif parameter_config["method"].get() == "MultiPeriodModifier": affected_subpops_grp = [] @@ -82,133 +152,109 @@ def add_modifier(self, pname, ptype, parameter_config, 
subpops): affected_subpops_grp = affected_subpops else: affected_subpops_grp += [str(n.get()) for n in grp_config["subpop"]] - affected_subpops = list(set(affected_subpops_grp)) + affected_subpops = list(affected_subpops_grp) spatial_groups = [] for grp_config in parameter_config["groups"]: if grp_config["subpop"].get() == "all": affected_subpops_grp = affected_subpops else: affected_subpops_grp = [str(n.get()) for n in grp_config["subpop"]] - - this_spatial_group = NPI.helpers.get_spatial_groups( - grp_config, affected_subpops_grp + this_spatial_group = SpatialGroups.from_subpopulations( + affected_subpops_grp, + ( + grp_config["subpop_groups"].get() + if grp_config["subpop_groups"].exists() + else None + ), ) - - # ungrouped subpop (all affected subpop by default) have one parameter per subpop - if this_spatial_group["ungrouped"]: - for sp in this_spatial_group["ungrouped"]: + # ungrouped subpop (all affected subpop by + # default) have one parameter per subpop + if this_spatial_group.ungrouped: + for sp in this_spatial_group.ungrouped: + dist = distribution_from_confuse_config(parameter_config["value"]) self.add_single_parameter( ptype=ptype, pname=pname, subpop=sp, - pdist=istribution_from_confuse_config( - parameter_config["value"] - ), - lb=parameter_config["value"]["a"].get(float), - ub=parameter_config["value"]["b"].get(float), + pdist=dist, ) - # grouped subpop have one parameter per group - if this_spatial_group["grouped"]: - for group in this_spatial_group["grouped"]: + if this_spatial_group.grouped: + for group in this_spatial_group.grouped: + dist = distribution_from_confuse_config(parameter_config["value"]) self.add_single_parameter( ptype=ptype, pname=pname, subpop=",".join(group), - pdist=istribution_from_confuse_config( - parameter_config["value"] - ), - lb=parameter_config["value"]["a"].get(float), - ub=parameter_config["value"]["b"].get(float), + pdist=dist, ) else: raise ValueError(f"Unknown method {parameter_config['method']}") - def 
add_single_parameter(self, ptype, pname, subpop, pdist, lb, ub): + def add_single_parameter( + self, + ptype: Literal["outcome_modifiers", "seir_modifiers"], + pname: str, + subpop: str, + pdist: DistributionABC, + ) -> None: """ - Adds a single parameter to the parameters list. + Adds a single parameter to the inference parameters representation. Args: - ptype (str): The parameter type. - pname (str): The parameter name. - subpop (str): The subpopulation affected by the parameter. + ptype: The parameter type, must be one of "outcome_modifiers" or + "seir_modifiers". + pname: The parameter name. + subpop: The subpopulation affected by the parameter. pdist: The distribution of the parameter. - lb: The lower bound of the parameter. - ub: The upper bound of the parameter. """ self.ptypes.append(ptype) self.pnames.append(pname) self.subpops.append(subpop) self.pdists.append(pdist) - self.ubs.append(ub) - self.lbs.append(lb) - - def build_from_config(self, global_config, subpop_names): - for config_part in ["seir_modifiers", "outcome_modifiers"]: - if global_config[config_part].exists(): - for npi in global_config[config_part]["modifiers"].get(): - if global_config[config_part]["modifiers"][npi][ - "perturbation" - ].exists(): - self.add_modifier( - pname=npi, - ptype=config_part, - parameter_config=global_config[config_part]["modifiers"][npi], - subpops=subpop_names, - ) + lower, upper = pdist.support + self._lower_bounds = np.append(self._lower_bounds, lower) + self._upper_bounds = np.append(self._upper_bounds, upper) - def print_summary(self): - print(f"There are {len(self.pnames)} parameters in the configuration.") - for p_idx in range(self.get_dim()): - print( - f"{self.ptypes[p_idx]}::{self.pnames[p_idx]} in [{self.lbs[p_idx]}, {self.ubs[p_idx]}]" - f" >> affected subpop: {self.subpops[p_idx]}" - ) - - def __str__(self) -> str: - from collections import Counter - - this_str = f"InferenceParameters: with {self.get_dim()} parameters: \n" - for key, value in 
Counter(self.ptypes).items(): - this_str += f" {key}: {value} parameters\n" + def get_dim(self) -> int: + """ + Get the dimension of the parameter space. - return this_str + Returns: + The dimension of the parameter space, which is a non-negative integer. + """ + return len(self) - def get_dim(self): - return len(self.pnames) + def get_parameters_for_subpop(self, subpop: str) -> list[int]: + """ + Get the indices of parameters relevant for a specific subpopulation. - def get_parameters_for_subpop(self, subpop: str) -> list: - """Returns the index parameters for a given subpopulation""" - parameters = [] - for i, sp in enumerate(self.subpops): - if sp == subpop: - parameters.append(i) - return parameters + Args: + subpop: The name of the subpopulation to pull parameters indexes for. - def __len__(self): - """ - so one can use the built-in python len function + Returns: + A list of indices corresponding to parameters that affect the specified + subpopulation. """ - return len(self.pnames) + return [i for i, s in enumerate(self.subpops) if s == subpop] - def draw_initial(self, n_draw=1): + def draw_initial(self, n_draw: int = 1) -> npt.NDArray[np.float64]: """ Draws initial parameter values. Args: - n_draw (int): Number of draws, e.g the number of slots or walkers + n_draw: Number of draws, e.g the number of slots or walkers. Returns: - np.ndarray: Array of initial parameter values. + Array of initial parameter values with shape (`n_draw`, dim). 
""" - p0 = np.zeros((n_draw, self.get_dim())) - for p_idx in range(self.get_dim()): - p0[:, p_idx] = self.pdists[p_idx](n_draw) - + dim = len(self) + p0 = np.zeros((n_draw, dim)) + for p_idx in range(dim): + p0[:, p_idx] = self.pdists[p_idx].sample(size=n_draw) return p0 - # TODO: write a more granular method the return for a single parameter and correct the proposal like we did - def check_in_bound(self, proposal) -> bool: + def check_in_bound(self, proposal: npt.NDArray[np.float64]) -> bool: """ Checks if the proposal is within parameter bounds. @@ -216,54 +262,38 @@ def check_in_bound(self, proposal) -> bool: proposal: The proposed parameter values. Returns: - bool: True if the proposal is within bounds, False otherwise. + `True` if the proposal is within bounds, `False` otherwise. """ - if self.hit_lbs(proposal=proposal).any() or self.hit_ubs(proposal=proposal).any(): - return False - return True - - def hit_lbs(self, proposal) -> np.ndarray: - return np.array((proposal < self.lbs)) - - def hit_ubs(self, proposal) -> np.ndarray: - """ - boolean vector of True if the parameter is bigger than the upper bound and False if not - """ - return np.array((proposal > self.ubs)) + return np.logical_and( + np.greater_equal(proposal, self._lower_bounds), + np.less_equal(proposal, self._upper_bounds), + ).all() def inject_proposal( self, - proposal, - snpi_df=None, - hnpi_df=None, - ): + proposal: npt.NDArray[np.float64], + snpi_df: pd.DataFrame, + hnpi_df: pd.DataFrame, + ) -> tuple[pd.DataFrame, pd.DataFrame]: """ Injects the proposal into model inputs, at the right place. Args: proposal: The proposed parameter values. - hnpi_df (pd.DataFrame): DataFrame for hnpi. - snpi_df (pd.DataFrame): DataFrame for snpi. + snpi_df: A dataframe representing the SEIR modifiers. + hnpi_df: A dataframe representing the outcome modifiers. Returns: - pd.DataFrame, pd.DataFrame: Modified hnpi_df and snpi_df. 
+ A tuple of modified DataFrames, the first one for SEIR modifiers and the + second for outcome modifiers. """ snpi_df_mod = snpi_df.copy(deep=True) hnpi_df_mod = hnpi_df.copy(deep=True) - - # Ideally this should lie in each submodules, e.g NPI.inject, parameter.inject - - for p_idx in range(self.get_dim()): - if self.ptypes[p_idx] == "seir_modifiers": - snpi_df_mod.loc[ - (snpi_df_mod["modifier_name"] == self.pnames[p_idx]) - & (snpi_df_mod["subpop"] == self.subpops[p_idx]), - "value", - ] = proposal[p_idx] - elif self.ptypes[p_idx] == "outcome_modifiers": - hnpi_df_mod.loc[ - (hnpi_df_mod["modifier_name"] == self.pnames[p_idx]) - & (hnpi_df_mod["subpop"] == self.subpops[p_idx]), - "value", - ] = proposal[p_idx] + for p_idx in range(len(self)): + df = snpi_df_mod if self.ptypes[p_idx] == "seir_modifiers" else hnpi_df_mod + df.loc[ + (df["modifier_name"] == self.pnames[p_idx]) + & (df["subpop"] == self.subpops[p_idx]), + "value", + ] = proposal[p_idx] return snpi_df_mod, hnpi_df_mod diff --git a/flepimop/gempyor_pkg/src/gempyor/output/__init__.py b/flepimop/gempyor_pkg/src/gempyor/output/__init__.py new file mode 100644 index 000000000..7a4b7fdb9 --- /dev/null +++ b/flepimop/gempyor_pkg/src/gempyor/output/__init__.py @@ -0,0 +1,14 @@ +"""Model output I/O API.""" + +__all__ = ( + "Chains", + "EmceeOutput", + "ModifierInfo", + "ModifierInfoPeriod", + "ModifiersDataFrames", + "OutputABC", +) + +from ._base import OutputABC +from ._emcee_output import EmceeOutput +from ._types import Chains, ModifierInfo, ModifierInfoPeriod, ModifiersDataFrames diff --git a/flepimop/gempyor_pkg/src/gempyor/output/_base.py b/flepimop/gempyor_pkg/src/gempyor/output/_base.py new file mode 100644 index 000000000..a53dbc479 --- /dev/null +++ b/flepimop/gempyor_pkg/src/gempyor/output/_base.py @@ -0,0 +1,82 @@ +"""ABC For model outputs.""" + +__all__: tuple[str, ...] 
class OutputABC(ABC):
    """
    Base class for model outputs.

    Resolves the output directory of a run as
    `<path_prefix>/model_output/<name>/<run_id>`, where `<name>` is the config's
    'name' value with the optional scenario names appended, and verifies that the
    directory exists. Subclasses implement `get_chains`.
    """

    def __init__(
        self,
        config: Path | str,
        run_id: str,
        seir_modifiers_scenario: str | None = None,
        outcome_modifers_scenario: str | None = None,
        path_prefix: Path | str | None = None,
    ) -> None:
        """
        Initialize the OutputABC.

        Args:
            config: Path to the config file used to generate the outputs.
            run_id: Run ID of the model, typically generated by the command to produce
                the outputs.
            seir_modifiers_scenario: Name of the 'seir_modifiers' scenario to process or
                `None` to not consider a scenario.
            outcome_modifers_scenario: Name of the 'outcome_modifiers' scenario to
                process or `None` to not consider a scenario.
            path_prefix: Prefix path to the model output directory. If `None`, use
                the current working directory.

        Raises:
            ValueError: If the config file does not have a 'name' key, including the
                case where the file is empty or its top level is not a mapping.
            FileNotFoundError: If the output path does not exist.
            NotADirectoryError: If the output path is not a directory.

        """
        # NOTE: the 'outcome_modifers_scenario' spelling is part of the public
        # signature and is kept as-is so keyword callers keep working.
        self._config = Path(config)
        self._run_id = run_id
        self._seir_modifiers_scenario = seir_modifiers_scenario
        self._outcome_modifers_scenario = outcome_modifers_scenario
        self._path_prefix = Path(path_prefix) if path_prefix is not None else Path.cwd()

        # Figure out the name of the run. `yaml.safe_load` returns `None` for an
        # empty document and may return a non-mapping, so guard before `.get`.
        with self._config.open("r") as f:
            conf = yaml.safe_load(f)
        self._name = conf.get("name") if isinstance(conf, dict) else None
        if self._name is None:
            msg = f"Config file {self._config} does not have a 'name' key."
            raise ValueError(msg)
        # YAML may parse an unquoted name as a non-string scalar (e.g. `name: 2024`).
        self._name = str(self._name).strip()
        if self._seir_modifiers_scenario is not None:
            self._name += f"_{self._seir_modifiers_scenario}"
        if self._outcome_modifers_scenario is not None:
            self._name += f"_{self._outcome_modifers_scenario}"

        # Assert the main directory exists
        path = self._path()
        if not path.exists():
            msg = f"Output path '{path}' does not exist."
            raise FileNotFoundError(msg)
        if not path.is_dir():
            msg = f"Output path '{path}' is not a directory."
            raise NotADirectoryError(msg)

    def _path(self, *args: str) -> Path:
        """
        Get a path inside the run's output directory.

        Args:
            *args: Optional path components appended below the run directory.

        Returns:
            `<path_prefix>/model_output/<name>/<run_id>` joined with `args`.
        """
        path = self._path_prefix / "model_output" / self._name / self._run_id
        if args:
            path /= Path(*args)
        return path

    @abstractmethod
    def get_chains(self) -> Chains:
        """Get the model chains."""
        raise NotImplementedError
class EmceeOutput(OutputABC):
    """
    Reader for the outputs of an EMCEE inference run.

    On construction the config is parsed to recover, for every inferred parameter,
    the modifier it belongs to (modified parameter, subpopulations and periods), and
    the `<run_id>_backend.h5` file under the path prefix is opened read-only.
    """

    def __init__(
        self,
        config: Path | str,
        run_id: str,
        seir_modifiers_scenario: str | None = None,
        outcome_modifers_scenario: str | None = None,
        path_prefix: Path | str | None = None,
    ) -> None:
        """
        Initialize the EmceeOutput.

        Args:
            config: Path to the config file used to generate the outputs.
            run_id: Run ID of the inference run.
            seir_modifiers_scenario: Name of the 'seir_modifiers' scenario or `None`.
            outcome_modifers_scenario: Name of the 'outcome_modifiers' scenario or
                `None`.
            path_prefix: Prefix path for the model output directory and the H5
                backend file. If `None`, the current working directory is used.

        Raises:
            NotImplementedError: If a modifier in the config uses a method other than
                'SinglePeriodModifier', 'MultiPeriodModifier' or 'StackedModifier'.
            KeyError: If an inferred parameter has no matching modifier entry parsed
                from the config.
            FileNotFoundError: If the H5 backend file does not exist.
            NotADirectoryError: If the H5 backend path exists but is not a file.
        """
        super().__init__(
            config, run_id, seir_modifiers_scenario, outcome_modifers_scenario, path_prefix
        )

        # Construct an instance of InferenceParameters so we can use
        # it to determine the order of the parameters in the H5 file
        with self._config.open("r") as f:
            conf = yaml.safe_load(f)
        cfg = confuse.RootView([confuse.ConfigSource.of(conf)])
        subpopulation_structure = SubpopulationStructure.from_confuse_config(
            cfg["subpop_setup"], path_prefix=path_prefix
        )
        self._inference_parameters = InferenceParameters(
            confuse.RootView([confuse.ConfigSource.of(conf)]),
            subpopulation_structure.subpop_names,
        )
        all_subpops = subpopulation_structure.subpop_names

        # Parse the underlying modifiers parameters and periods directly from the
        # config. Entries are keyed by (kind, modifier name, sorted subpops); the
        # tuple itself is the key, so distinct entries can never collide.
        modifiers_lib: dict[
            tuple[str, str, tuple[str, ...]],
            tuple[str | None, list[ModifierInfoPeriod]],
        ] = {}
        for kind in ("seir", "outcome"):
            modifiers_conf = conf.get(f"{kind}_modifiers", {}).get("modifiers", {})
            for modifier_name, modifier_conf in modifiers_conf.items():
                parameter = modifier_conf.get("parameter")
                method = modifier_conf["method"]
                if method == "SinglePeriodModifier":
                    spatial_groups = SpatialGroups.from_subpopulations(
                        self._affected_subpops(modifier_conf, all_subpops),
                        modifier_conf.get("subpop_groups", None),
                    )
                    periods = [
                        ModifierInfoPeriod(
                            start_date=modifier_conf["period_start_date"],
                            end_date=modifier_conf["period_end_date"],
                        )
                    ]
                    for _, subpop_group in spatial_groups:
                        # Sorted so the key matches the sorted lookup key below.
                        key = (kind, modifier_name, tuple(sorted(subpop_group)))
                        modifiers_lib[key] = (parameter, periods)
                elif method == "MultiPeriodModifier":
                    for group_conf in modifier_conf.get("groups", []):
                        # NOTE(review): the subpop selection is read from the
                        # modifier level, not from `group_conf`; confirm whether a
                        # per-group 'subpop' entry should be honoured here.
                        spatial_groups = SpatialGroups.from_subpopulations(
                            self._affected_subpops(modifier_conf, all_subpops),
                            group_conf.get("subpop_groups", None),
                        )
                        periods = [
                            ModifierInfoPeriod(
                                start_date=p["start_date"], end_date=p["end_date"]
                            )
                            for p in group_conf.get("periods", [])
                        ]
                        for _, subpop_group in spatial_groups:
                            key = (kind, modifier_name, tuple(sorted(subpop_group)))
                            modifiers_lib[key] = (parameter, periods)
                elif method == "StackedModifier":
                    # Inference is not done directly on StackedModifiers, but instead
                    # on the underlying constituents which are captured above.
                    continue
                else:
                    msg = (
                        f"Unsupported modifier method '{method}' for "
                        f"the {kind} modifier '{modifier_name}'."
                    )
                    raise NotImplementedError(msg)

        # Create a list of ModifierInfo objects from the
        # inference parameters and parsed modifiers config
        self._modifiers: list[ModifierInfo] = []
        for i in range(len(self._inference_parameters)):
            # ptypes entries look like 'seir_modifiers' / 'outcome_modifiers'.
            kind = self._inference_parameters.ptypes[i].removesuffix("_modifiers")
            modifier_name = self._inference_parameters.pnames[i]
            subpops = sorted(self._inference_parameters.subpops[i].split(","))
            parameter, periods = modifiers_lib[(kind, modifier_name, tuple(subpops))]
            self._modifiers.append(
                ModifierInfo(
                    kind=kind,
                    name=modifier_name,
                    subpops=subpops,
                    periods=periods,
                    parameter=parameter,
                )
            )

        # Open the HDF5 backend file produced by EMCEE which contains the chains
        h5 = self._path_prefix / f"{self._run_id}_backend.h5"
        if not h5.exists():
            msg = f"The EMCEE inference H5 backend file '{h5}' does not exist."
            raise FileNotFoundError(msg)
        if not h5.is_file():
            msg = f"The EMCEE inference H5 backend file '{h5}' is not a file."
            # Exception type kept for backward compatibility with existing callers.
            raise NotADirectoryError(msg)
        self._reader = HDFBackend(h5, read_only=True)

    @staticmethod
    def _affected_subpops(modifier_conf, all_subpops):
        """
        Resolve a modifier's 'subpop' entry to the subpopulations it selects.

        Args:
            modifier_conf: The modifier's config mapping.
            all_subpops: All subpopulation names of the model.

        Returns:
            `all_subpops` when 'subpop' is missing or equal to "all", otherwise the
            configured 'subpop' value unchanged.
        """
        subpop = modifier_conf.get("subpop", "all")
        return all_subpops if subpop == "all" else subpop

    def get_chains(self) -> Chains:
        """
        Read the chains from the EMCEE H5 backend.

        Returns:
            A `Chains` instance whose `log_probability` is the backend's log
            probability array transposed, whose `samples` is the backend's chain
            array with its first two axes swapped, and whose `shape` is the shape
            of that reordered samples array.
        """
        log_prob = self._reader.get_log_prob().T
        samples = np.transpose(self._reader.get_chain(), axes=(1, 0, 2))
        return Chains(
            shape=samples.shape,
            log_probability=log_prob,
            samples=samples,
            modifiers=self._modifiers,
        )
= () + +from dataclasses import dataclass +from datetime import date +from typing import Literal, NamedTuple + +import numpy as np +import numpy.typing as npt +import pandas as pd + +from .._pydantic_ext import _ensure_list + + +class ModifiersDataFrames(NamedTuple): + """ + DataFrames to hold modifier information. + + Attributes: + snpi: List of DataFrames for SEIR modifiers. + hnpi: List of DataFrames for outcome modifiers. + + """ + + snpi: list[pd.DataFrame] + hnpi: list[pd.DataFrame] + + +@dataclass(frozen=True) +class ModifierInfoPeriod: + """ + Dataclass to hold information about a modifier period. + + Attributes: + start_date: The start date of the modifier period. + end_date: The end date of the modifier period. + + Examples: + >>> from datetime import date + >>> from pprint import pprint + >>> from gempyor.output import ModifierInfoPeriod + >>> period = ModifierInfoPeriod( + ... start_date=date(2020, 1, 1), + ... end_date=date(2020, 12, 31), + ... ) + >>> pprint(period) + ModifierInfoPeriod(start_date=datetime.date(2020, 1, 1), + end_date=datetime.date(2020, 12, 31)) + + """ + + start_date: date + end_date: date + + +@dataclass(frozen=True) +class ModifierInfo: + """ + Dataclass to hold information about a modifier used in the model. + + Attributes: + kind: The kind of modifier, either 'seir' or 'outcome'. + name: The name of the modifier. + subpops: A list of subpopulation names the modifier applies to. + start_date: The start date of the modifier. + end_date: The end date of the modifier. + parameter: The name of the parameter being modified. + + Examples: + >>> from datetime import date + >>> from pprint import pprint + >>> from gempyor.output import ModifierInfo, ModifierInfoPeriod + >>> periods = [ + ... ModifierInfoPeriod( + ... start_date=date(2020, 1, 1), + ... end_date=date(2020, 1, 31), + ... ), + ... ModifierInfoPeriod( + ... start_date=date(2020, 3, 1), + ... end_date=date(2020, 3, 31), + ... ), + ... 
] + >>> pprint(periods) + [ModifierInfoPeriod(start_date=datetime.date(2020, 1, 1), + end_date=datetime.date(2020, 1, 31)), + ModifierInfoPeriod(start_date=datetime.date(2020, 3, 1), + end_date=datetime.date(2020, 3, 31))] + >>> modifier_info = ModifierInfo( + ... kind="seir", + ... name="seasonal_gamma", + ... subpops=["subpop1", "subpop2"], + ... periods=periods, + ... parameter="gamma", + ... ) + >>> pprint(modifier_info) + ModifierInfo(kind='seir', + name='seasonal_gamma', + subpops=['subpop1', 'subpop2'], + periods=[ModifierInfoPeriod(start_date=datetime.date(2020, 1, 1), + end_date=datetime.date(2020, 1, 31)), + ModifierInfoPeriod(start_date=datetime.date(2020, 3, 1), + end_date=datetime.date(2020, 3, 31))], + parameter='gamma') + + """ + + kind: Literal["seir", "outcome"] + name: str + subpops: list[str] + periods: list[ModifierInfoPeriod] + parameter: str + + +@dataclass(frozen=True) +class Chains: + # pylint: disable=line-too-long + """ + Dataclass to hold the chains of a model output. + + Attributes: + shape: The shape of the chains represented as a tuple of integers corresponding + to number of chains, iterations, and parameters. + log_probability: A 2D numpy array of the evaluated log probabilities for each + chain and iteration. Has shape (n_chains, n_iterations). + samples: A 3D numpy array of the sampled values for each chain, iteration, + and parameter. Has shape (n_chains, n_iterations, n_parameters). + modifiers: A list of `ModifierInfo` instances describing the modifiers used in + the model. Corresponds to the order of parameters in the `samples` array. 
+ + Examples: + >>> from datetime import date + >>> from pprint import pprint + >>> import numpy as np + >>> from gempyor.output import Chains, ModifierInfo, ModifierInfoPeriod + >>> rng = np.random.default_rng(12345) + >>> shape = (4, 30, 2) # 4 chains, 30 iterations, 2 parameters + >>> log_probability = rng.lognormal(size=(shape[0], shape[1])) + >>> samples = rng.normal(size=shape) + >>> modifiers = [ + ... ModifierInfo( + ... kind="seir", + ... name="seasonal_beta", + ... subpops=["subpop1"], + ... periods=[ + ... ModifierInfoPeriod( + ... start_date=date(2020, 1, 1), + ... end_date=date(2020, 6, 30), + ... ), + ... ModifierInfoPeriod( + ... start_date=date(2020, 7, 1), + ... end_date=date(2020, 12, 31), + ... ), + ... ], + ... parameter="beta", + ... ), + ... ModifierInfo( + ... kind="outcome", + ... name="hospitalization_rate", + ... subpops=["subpop1", "subpop2"], + ... periods=[ + ... ModifierInfoPeriod( + ... start_date=date(2020, 1, 1), + ... end_date=date(2020, 12, 31), + ... ), + ... ], + ... parameter="hosp::probability", + ... ), + ... ] + >>> chains = Chains( + ... shape=shape, + ... log_probability=log_probability, + ... samples=samples, + ... modifiers=modifiers, + ... 
) + >>> chains.shape + (4, 30, 2) + >>> chains.log_probability.shape + (4, 30) + >>> chains.samples.shape + (4, 30, 2) + >>> pprint(chains.modifiers) + [ModifierInfo(kind='seir', + name='seasonal_beta', + subpops=['subpop1'], + periods=[ModifierInfoPeriod(start_date=datetime.date(2020, 1, 1), + end_date=datetime.date(2020, 6, 30)), + ModifierInfoPeriod(start_date=datetime.date(2020, 7, 1), + end_date=datetime.date(2020, 12, 31))], + parameter='beta'), + ModifierInfo(kind='outcome', + name='hospitalization_rate', + subpops=['subpop1', 'subpop2'], + periods=[ModifierInfoPeriod(start_date=datetime.date(2020, 1, 1), + end_date=datetime.date(2020, 12, 31))], + parameter='hosp::probability')] + >>> chains_subset = chains.subset(chains=[0, 1], iterations=[0, 1, 2]) + >>> chains_subset.shape + (2, 3, 2) + >>> chains_subset.log_probability.shape + (2, 3) + >>> chains_subset.samples.shape + (2, 3, 2) + >>> chains.flatten_samples().shape + (120, 2) + >>> modifiers_dfs = chains.to_modifiers_dataframes() + >>> len(modifiers_dfs.snpi) + 120 + >>> modifiers_dfs.snpi[0] + subpop modifier_name start_date end_date parameter value + 0 subpop1 seasonal_beta 2020-01-01,2020-07-01 2020-06-30,2020-12-31 beta -0.811887 + >>> modifiers_dfs.hnpi[0] + subpop modifier_name start_date end_date parameter value + 0 subpop1,subpop2 hospitalization_rate 2020-01-01 2020-12-31 hosp::probability -0.025538 + + """ + # pylint: enable=line-too-long + + shape: tuple[int, int, int] + log_probability: npt.NDArray[np.float64] + samples: npt.NDArray[np.float64] + modifiers: list[ModifierInfo] + + def subset( + self, + chains: list[int] | int | None = None, + iterations: list[int] | int | None = None, + ) -> "Chains": + """ + Create a new `Chains` instance that is a subset of the current instance. + + Args: + iterations: A list, or integer for just one, of iteration indices to include + in the subset. If `None`, include all iterations. 
+ chains: A list, or integer for just one, of chain indices to include in the + subset. If `None`, include all chains. + + Returns: + A new `Chains` instance containing only the specified chains and iterations. + + """ + chains = list(range(self.shape[0])) if chains is None else _ensure_list(chains) + iterations = ( + list(range(self.shape[1])) if iterations is None else _ensure_list(iterations) + ) + return Chains( + shape=(len(chains), len(iterations), self.shape[2]), + log_probability=self.log_probability[np.ix_(chains, iterations)], + samples=self.samples[np.ix_(chains, iterations, np.arange(self.shape[2]))], + modifiers=self.modifiers, + ) + + def flatten_samples(self) -> npt.NDArray[np.float64]: + """ + Flatten the samples array to 2D. + + Returns: + A 2D numpy array of shape (n_chains * n_iterations, n_parameters). + + """ + return self.samples.reshape(-1, self.shape[2]) + + def to_modifiers_dataframes(self) -> ModifiersDataFrames: + """ + Convert the chains to a `ModifiersDataFrames` instance. + + Returns: + A `ModifiersDataFrames` instance containing DataFrames for SEIR and outcome + modifiers with their corresponding sampled values. 
+ """ + # Construct base DataFrames without values + empty_base_df = pd.DataFrame( + columns=[ + "subpop", + "modifier_name", + "start_date", + "end_date", + "parameter", + "value", + ] + ) + snpi_param_idx = [] + hnpi_param_idx = [] + snpi_base = [] + hnpi_base = [] + for i, modifier in enumerate(self.modifiers): + param_idx = snpi_param_idx if modifier.kind == "seir" else hnpi_param_idx + base = snpi_base if modifier.kind == "seir" else hnpi_base + param_idx.append(i) + base.append( + { + "subpop": ",".join(modifier.subpops), + "modifier_name": modifier.name, + "start_date": ",".join( + [p.start_date.strftime("%Y-%m-%d") for p in modifier.periods] + ), + "end_date": ",".join( + [p.end_date.strftime("%Y-%m-%d") for p in modifier.periods] + ), + "parameter": modifier.parameter, + } + ) + snpi_base_df = ( + pd.DataFrame.from_records(snpi_base) if snpi_base else empty_base_df.copy() + ) + hnpi_base_df = ( + pd.DataFrame.from_records(hnpi_base) if hnpi_base else empty_base_df.copy() + ) + # Expand the base DataFrame for each chain and iteration + nchains, niterations, _ = self.shape + do_snpi = bool(snpi_param_idx) + do_hnpi = bool(hnpi_param_idx) + snpi_dfs = [] + hnpi_dfs = [] + for i in range(nchains): + for j in range(niterations): + snpi_df = snpi_base_df.copy() + if do_snpi: + snpi_df["value"] = self.samples[i, j, snpi_param_idx] + hnpi_df = hnpi_base_df.copy() + if do_hnpi: + hnpi_df["value"] = self.samples[i, j, hnpi_param_idx] + snpi_dfs.append(snpi_df) + hnpi_dfs.append(hnpi_df) + + return ModifiersDataFrames( + snpi=snpi_dfs, + hnpi=hnpi_dfs, + ) diff --git a/flepimop/gempyor_pkg/src/gempyor/seir.py b/flepimop/gempyor_pkg/src/gempyor/seir.py index a31fc7a18..c3d55198a 100644 --- a/flepimop/gempyor_pkg/src/gempyor/seir.py +++ b/flepimop/gempyor_pkg/src/gempyor/seir.py @@ -17,7 +17,6 @@ from .model_info import ModelInfo from .utils import Timer, _nslots_random_seeds, read_df - logger = logging.getLogger(__name__) @@ -194,7 +193,10 @@ def 
build_step_source_arg( } check_parameter_positivity( - fnct_args["parameters"], modinf.parameters.pnames, modinf.dates, modinf.subpop_pop + fnct_args["parameters"], + modinf.parameters.pnames, + modinf.dates, + modinf.subpop_struct.subpop_names, ) return fnct_args diff --git a/flepimop/gempyor_pkg/src/gempyor/simulate.py b/flepimop/gempyor_pkg/src/gempyor/simulate.py index 182e43a03..2401dbb3b 100644 --- a/flepimop/gempyor_pkg/src/gempyor/simulate.py +++ b/flepimop/gempyor_pkg/src/gempyor/simulate.py @@ -1,199 +1,63 @@ -""" -Tools to forward simulate a model with `gempyor`. -""" +"""Tools to forward simulate a model with `gempyor`.""" -#!/usr/bin/env python - -## -# @file -# @brief Runs hospitalization model -# -# @details -# -# ## Configuration Items -# -# ```yaml -# name: -# setup_name: -# start_date: -# end_date: -# dt: float -# nslots: overridden by the -n/--nslots script parameter -# subpop_setup: -# geodata: -# mobility: -# -# seir: -# parameters -# alpha: -# sigma: -# gamma: -# R0s: -# -# seir_modifiers: -# scenarios: -# - -# - -# - ... -# settings: -# : -# method: choose one - "SinglePeriodModifier", ", "StackedModifier" -# ... -# : -# method: choose one - "SinglePeriodModifier", "", "StackedModifier" -# ... 
-# -# seeding: -# method: choose one - "PoissonDistributed", "FolderDraw" -# ``` -# -# ### seir_modifiers::scenarios::settings:: -# -# If {method} is -# ```yaml -# seir_modifiers: -# scenarios: -# : -# method: SinglePeriodModifier -# parameter: choose one - "alpha, sigma, gamma, r0" -# period_start_date: -# period_end_date: -# value: -# subpop: optional -# ``` -# -# If {method} is -# ```yaml -# seir_modifiers: -# scenarios: -# : -# method: -# period_start_date: -# period_end_date: -# value: -# subpop: optional -# ``` -# -# If {method} is StackedModifier -# ```yaml -# seir_modifiers: -# scenarios: -# : -# method: StackedModifier -# scenarios: -# ``` -# -# ### seeding -# -# If {seeding::method} is PoissonDistributed -# ```yaml -# seeding: -# method: PoissonDistributed -# lambda_file: -# ``` -# -# If {seeding::method} is FolderDraw -# ```yaml -# seeding: -# method: FolderDraw -# folder_path: \; make sure this ends in a '/' -# ``` -# -# ## Input Data -# -# * {subpop_setup::geodata} is a csv with columns {subpop_setup::subpop_names} and {subpop_setup::subpop_pop} -# * {subpop_setup::mobility} -# -# If {seeding::method} is PoissonDistributed -# * {seeding::lambda_file} -# -# If {seeding::method} is FolderDraw -# * {seeding::folder_path}/[simulation ID].impa.csv -# -# ## Output Data -# -# * model_output/{setup_name}_[scenario]/[simulation ID].seir.[csv/parquet] -# * model_parameters/{setup_name}_[scenario]/[simulation ID].spar.[csv/parquet] -# * model_parameters/{setup_name}_[scenario]/[simulation ID].snpi.[csv/parquet] -# ## Configuration Items -# -# ```yaml -# outcomes: -# method: delayframe # Only fast is supported atm. Makes fast delay_table computations. Later agent-based method ? -# paths: -# param_from_file: TRUE # -# param_subpop_file: # OPTIONAL: File with param per csv. 
For each param in this file -# scenarios: # Outcomes scenarios to run -# - low_death_rate -# - mid_death_rate -# settings: # Setting for each scenario -# low_death_rate: -# new_comp1: # New compartement name -# source: incidence # Source of the new compartement: either an previously defined compartement or "incidence" for diffI of the SEIR -# probability: # Branching probability from source -# delay: # Delay from incidence of source to incidence of new_compartement -# duration: # OPTIONAL ! Duration in new_comp. If provided, the model add to it's -# #output "new_comp1_curr" with current amount in new_comp1 -# new_comp2: # Example for a second compatiment -# source: new_comp1 -# probability: -# delay: -# duration: -# death_tot: # Possibility to combine compartements for death. -# sum: ['death_hosp', 'death_ICU', 'death_incid'] -# -# mid_death_rate: -# ... -# -# ## Input Data -# -# * {param_subpop_file} is a csv with columns subpop, parameter, value. Parameter is constructed as, e.g for comp1: -# probability: Pnew_comp1|source -# delay: Dnew_comp1 -# duration: Lnew_comp1 - - -# ## Output Data -# * {output_path}/model_output/{setup_name}_[scenario]/[simulation ID].hosp.parquet - - -## @cond - -import time, warnings, sys - -from pathlib import Path +import pickle +import subprocess +import sys +import time +import warnings from collections.abc import Iterable +from itertools import product +from pathlib import Path +from typing import Any +import click from confuse import Configuration -from click import Context, pass_context -from . import seir, outcomes, model_info, utils -from .shared_cli import ( - config_files_argument, - config_file_options, - parse_config_files, - cli, - click_helpstring, - mock_context, -) +from . 
def _simulate_seir_and_outcomes(
    modinf: ModelInfo,
    cfg: Configuration,
    chains: Chains | None,
    first_sim_index: int,
    nslots: int,
    jobs: int,
    run_seir: bool = True,
    run_outcomes: bool = True,
) -> None:
    """
    Run the SEIR stage and then the outcomes stage of a simulation.

    Each stage is delegated to its module's parallel runner and can be switched
    off independently; with both switched off this function does nothing.

    Args:
        modinf: A `ModelInfo` instance corresponding to the simulation to be run.
        cfg: A `Configuration` instance containing the simulation configuration,
            forwarded to the SEIR runner.
        chains: Optional `Chains` instance containing MCMC samples. It is accepted
            for the caller's convenience but is not forwarded to either runner by
            this function.
        first_sim_index: The index of the first simulation, passed to the outcomes
            runner as `sim_id2write`.
        nslots: The number of simulation slots, passed to the outcomes runner.
        jobs: The number of parallel jobs to use for both runners.
        run_seir: Whether to run the SEIR model.
        run_outcomes: Whether to run the outcomes model.
    """
    if run_seir:
        seir.run_parallel_SEIR(modinf, cfg, n_jobs=jobs)
    if not run_outcomes:
        return
    outcomes.run_parallel_outcomes(
        modinf,
        sim_id2write=first_sim_index,
        nslots=nslots,
        n_jobs=jobs,
    )
cfg["write_csv"].get(bool) + write_parquet = cfg["write_parquet"].get(bool) + first_sim_index = cfg["first_sim_index"].get(int) + in_run_id = cfg["in_run_id"].get(str) if cfg["in_run_id"].exists() else None + out_run_id = cfg["out_run_id"].get(str) if cfg["out_run_id"].exists() else None + config_filepath = cfg["config_src"].as_str_seq() + n_jobs = cfg["jobs"].get(int) + run_seir = cfg["seir"].exists() + run_outcomes = cfg["outcomes"].exists() + + # Load samples from chains if provided + chains: Chains | None = None + if from_chains is not None: + with from_chains.open("rb") as f: + chains = pickle.load(f) + if not isinstance(chains, Chains): + raise ValueError(f"Expected a Chains instance, got {type(chains)}") + if verbose: + print( + f"Loaded chains with shape {chains.shape} " + f"from {from_chains} for simulating." + ) + n_chains, n_iterations, _ = chains.shape + if n_chains != n_jobs: + raise ValueError( + f"Number of chains in {from_chains} is {n_chains}, which " + f"does not match the number of jobs to be run, {n_jobs}." + ) + if n_iterations < nslots: + raise ValueError( + f"Number of iterations in {from_chains} is {n_iterations}, " + f"which is less than the number of slots to be run, {nslots}." 
+ ) for seir_modifiers_scenario, outcome_modifiers_scenario in scenarios_combinations: start = time.monotonic() if verbose: print(f"Running {seir_modifiers_scenario}_{outcome_modifiers_scenario}") - modinf = model_info.ModelInfo( + modinf = ModelInfo( config=cfg, - nslots=nchains, + nslots=nslots, seir_modifiers_scenario=seir_modifiers_scenario, outcome_modifiers_scenario=outcome_modifiers_scenario, - write_csv=cfg["write_csv"].get(bool), - write_parquet=cfg["write_parquet"].get(bool), - first_sim_index=cfg["first_sim_index"].get(int), - in_run_id=cfg["in_run_id"].get(str) if cfg["in_run_id"].exists() else None, - # in_prefix=config["name"].get() + "/", - out_run_id=cfg["out_run_id"].get(str) if cfg["out_run_id"].exists() else None, - # out_prefix=config["name"].get() + "/" + str(seir_modifiers_scenario) + "/" + out_run_id + "/", - config_filepath=cfg["config_src"].as_str_seq(), + write_csv=write_csv, + write_parquet=write_parquet, + first_sim_index=first_sim_index, + in_run_id=in_run_id, + out_run_id=out_run_id, + config_filepath=config_filepath, ) if verbose: + print(f">> Running from config {config_filepath}") + print( + f">> Starting {nslots} model runs beginning " + f"from {first_sim_index} on {n_jobs} processes" + ) print( - f""" - >> Running from config {cfg["config_src"].as_str_seq()} - >> Starting {modinf.nslots} model runs beginning from {modinf.first_sim_index} on {cfg["jobs"].get(int)} processes - >> ModelInfo *** {modinf.setup_name} *** from {modinf.ti} to {modinf.tf} - >> Running scenario {seir_modifiers_scenario}_{outcome_modifiers_scenario} - >> using ***{modinf.get_engine()}*** engine for trajectories - """ + f">> ModelInfo *** {modinf.setup_name} " + f"*** from {modinf.ti} to {modinf.tf}" ) - # (there should be a run function) - if cfg["seir"].exists(): - seir.run_parallel_SEIR(modinf, config=cfg, n_jobs=cfg["jobs"].get(int)) - if cfg["outcomes"].exists(): - outcomes.run_parallel_outcomes( - sim_id2write=cfg["first_sim_index"].get(int), - 
@cli.command(
    name="simulate",
    params=[config_files_argument]
    + list(config_file_options.values())
    + [
        click.Option(
            param_decls=["--from-chains"],
            type=click.Path(exists=True, dir_okay=False),
            default=None,
            show_default=True,
            help=(
                "Optional path to a chains pickle file to run simulations from. "
                "Will override the modifiers within the config file(s)."
            ),
        )
    ],
    context_settings=dict(help_option_names=["-h", "--help"]),
)
@click.pass_context
def _click_simulate(ctx: click.Context, **kwargs: Any) -> int:
    """Forward simulate a model using gempyor."""
    # NOTE: the docstring above is used by click as the command's help text, so it
    # is deliberately left unchanged.
    cfg = parse_config_files(utils.config, ctx, **kwargs)
    # `click.Path` hands back a plain string by default, while `simulate` calls
    # `from_chains.open(...)`; convert so the option works from the CLI.
    from_chains = kwargs.get("from_chains")
    return simulate(
        cfg, from_chains=None if from_chains is None else Path(from_chains)
    )
import file_paths -from ._pydantic_ext import _evaled_expression - logger = logging.getLogger(__name__) @@ -126,8 +120,8 @@ def command_safe_run( Raises: RuntimeError: If fail_on_fail=True and the command fails, an error will be thrown. """ - import subprocess import shlex # using shlex to split the command because it's not obvious https://docs.python.org/3/library/subprocess.html#subprocess.Popen + import subprocess sr = subprocess.Popen( shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE @@ -192,7 +186,8 @@ def search_and_import_plugins_class( # Look for all possible plugins and import them # https://stackoverflow.com/questions/67631/how-can-i-import-a-module-dynamically-given-the-full-path # unfortunatelly very complicated, this is cpython only ?? - import sys, os + import os + import sys full_path = os.path.join(path_prefix, plugin_file_path) sys.path.append(os.path.dirname(full_path)) @@ -1196,3 +1191,96 @@ def _trim_s3_path(path: str | Path) -> str | Path: PosixPath('s3:/foo/bar.txt') """ return path.lstrip("s3:") if isinstance(path, str) else path + + +T = TypeVar("T") + + +def _flatten_list_of_lists(value: list[list[T]] | list[T] | None) -> list[T]: + """ + Flatten a list of lists into a single list. + + Args: + value: A value to flatten. + + Returns: + A flattened list. If `value` is None, an empty list is returned. If `value` is + not a list, it is wrapped in a list. If `value` is an empty list, an empty list + is returned. If `value` is a list of lists, it is flattened.
+ + Examples: + >>> from gempyor.utils import _flatten_list_of_lists + >>> _flatten_list_of_lists(None) + [] + >>> _flatten_list_of_lists(42) + [42] + >>> _flatten_list_of_lists([1, 2, 3]) + [1, 2, 3] + >>> _flatten_list_of_lists([[1, 2], [3, 4]]) + [1, 2, 3, 4] + >>> _flatten_list_of_lists([[1, 2], "a", "b", [2.1, [4.3, 5.6]]]) + [1, 2, 'a', 'b', 2.1, 4.3, 5.6] + """ + if value is None: + return [] + if not isinstance(value, list): + return [value] + if not value: + return [] + return [x for subvalue in value for x in _flatten_list_of_lists(subvalue)] + + +def _make_list_of_list(value: list[list[T]] | list[T] | T | None) -> list[list[T]]: + # pylint: disable=line-too-long + """ + Construct a list of lists from a flat list-ish object. + + Args: + value: A value to coerce into a list of lists. + + Returns: + A list of lists. If `value` is None, an empty list is returned. If `value` is + not a list, it is wrapped in a single-element list of lists. If `value` is an + empty list, an empty list is returned. If `value` is a list of lists, it is + returned as-is. + + Examples: + >>> from gempyor.utils import _make_list_of_list + >>> _make_list_of_list(None) + [] + >>> _make_list_of_list(42) + [[42]] + >>> _make_list_of_list([1, 2, 3]) + [[1, 2, 3]] + >>> _make_list_of_list([[1, 2], [3, 4]]) + [[1, 2], [3, 4]] + >>> _make_list_of_list({"key": "value"}) + [[{'key': 'value'}]] + >>> _make_list_of_list("string") + [['string']] + >>> _make_list_of_list([1, 2, [3, 4]]) + Traceback (most recent call last): + ... + ValueError: value=[1, 2, [3, 4]] contains a mix of lists and non-lists, cannot coerce to list of lists. 
+ >>> _make_list_of_list([]) + [] + >>> _make_list_of_list([[]]) + [[]] + + """ + # pylint: enable=line-too-long + if value is None: + return [] + if not isinstance(value, list): + return [[value]] + if not value: + return [] + if all(isinstance(x, list) for x in value): + return value + if any(isinstance(x, list) for x in value): + msg = ( + f"{value=} contains a mix of lists and " + "non-lists, cannot coerce to list of lists." + ) + raise ValueError(msg) + return [value] diff --git a/flepimop/gempyor_pkg/tests/distributions/test_distributions_common.py b/flepimop/gempyor_pkg/tests/distributions/test_distributions_common.py index 62a4ee20b..1fa63d7ea 100644 --- a/flepimop/gempyor_pkg/tests/distributions/test_distributions_common.py +++ b/flepimop/gempyor_pkg/tests/distributions/test_distributions_common.py @@ -1,6 +1,7 @@ import numpy as np import pytest from gempyor.distributions import DistributionABC +from pydantic import PrivateAttr class DummyDistribution(DistributionABC): @@ -8,9 +9,8 @@ class DummyDistribution(DistributionABC): distribution: str = "dummy" - def __call__(self) -> float | int: - """A shortcut for `self.sample(size=1)`.""" - return self.sample(size=1).item() + _lower_bound: float = PrivateAttr(default=0.0) + _upper_bound: float = PrivateAttr(default=1.0) def _sample_from_generator( self, size: int | tuple[int, ...], rng: np.random.Generator @@ -32,7 +32,6 @@ def test_stochastic_sampling_with_default_rng() -> None: dist = DummyDistribution() sample1 = dist.sample(size=10) sample2 = dist.sample(size=10) - assert not np.array_equal(sample1, sample2) diff --git a/flepimop/gempyor_pkg/tests/inference_parameter/test_inference_parameter_class.py b/flepimop/gempyor_pkg/tests/inference_parameter/test_inference_parameter_class.py new file mode 100644 index 000000000..5bb47794a --- /dev/null +++ b/flepimop/gempyor_pkg/tests/inference_parameter/test_inference_parameter_class.py @@ -0,0 +1,820 @@ +"""Unit tests for the 
`gempyor.inference_parameter.InferenceParameter` class.""" + +from math import inf +from typing import Any, Literal, NamedTuple, TypedDict + +import numpy as np +import numpy.typing as npt +import pandas as pd +import pytest +from gempyor.distributions import ( + BetaDistribution, + DistributionABC, + GammaDistribution, + UniformDistribution, +) +from gempyor.inference_parameter import InferenceParameters +from gempyor.testing import create_confuse_configview_from_dict + + +class AddSingleParameterArg(TypedDict): + """Arguments for adding a single parameter.""" + + ptype: Literal["outcome_modifiers", "seir_modifiers"] + pname: str + subpop: str + pdist: DistributionABC + + +class AddModifierArg(NamedTuple): + """Arguments for adding a modifier parameter.""" + + pname: str + ptype: Literal["outcome_modifiers", "seir_modifiers"] + parameter_config: dict[str, Any] + + def as_dict(self) -> dict[str, Any]: + """Convert to a dictionary.""" + return { + "pname": self.pname, + "ptype": self.ptype, + "parameter_config": create_confuse_configview_from_dict(self.parameter_config), + } + + def number_of_subpopulations(self, subpopulations: list[str]) -> int: + """Return the number of subpopulations.""" + if ( + self.parameter_config.get("method", "SinglePeriodModifier") + == "MultiPeriodModifier" + ): + return sum( + AddModifierArg( + pname=self.pname, ptype=self.ptype, parameter_config=group + ).number_of_subpopulations(subpopulations) + for group in self.parameter_config.get("groups", []) + ) + if subpop_groups := self.parameter_config.get("subpop_groups"): + return len(subpop_groups) + if self.parameter_config["subpop"] == "all": + return len(subpopulations) + return len(self.parameter_config["subpop"]) + + +@pytest.mark.parametrize( + "single_parameter_args", + [ + [ + AddSingleParameterArg( + ptype="incidH::delay", + pname="Ro", + subpop="GA", + pdist=GammaDistribution(shape=14.0, scale=0.5), + ), + ], + [ + AddSingleParameterArg( + ptype="seir_modifiers", + 
pname="gamma", + subpop="GA", + pdist=GammaDistribution(shape=2.0, scale=0.5), + ), + AddSingleParameterArg( + ptype="seir_modifiers", + pname="beta", + subpop="GA", + pdist=GammaDistribution(shape=2.0, scale=0.5), + ), + ], + [ + AddSingleParameterArg( + ptype="outcome_modifiers", + pname="incidH::probability", + subpop="USA", + pdist=BetaDistribution(alpha=2.0, beta=5.0), + ), + AddSingleParameterArg( + ptype="seir_modifiers", + pname="Ro", + subpop="USA", + pdist=GammaDistribution(shape=1.5, scale=0.5), + ), + AddSingleParameterArg( + ptype="outcome_modifiers", + pname="incidH::delay", + subpop="Canada", + pdist=GammaDistribution(shape=10.0, scale=0.5), + ), + AddSingleParameterArg( + ptype="seir_modifiers", + pname="beta", + subpop="Canada", + pdist=GammaDistribution(shape=2.0, scale=0.5), + ), + ], + ], +) +def test_adding_single_parameters( + single_parameter_args: list[AddSingleParameterArg], +) -> None: + """Test adding single parameters to an empty `InferenceParameters` class.""" + inference_params = InferenceParameters(create_confuse_configview_from_dict({}), []) + for arg in single_parameter_args: + inference_params.add_single_parameter(**arg) + assert inference_params.get_dim() == len(single_parameter_args) + assert inference_params.ptypes == [arg["ptype"] for arg in single_parameter_args] + assert inference_params.pnames == [arg["pname"] for arg in single_parameter_args] + assert inference_params.subpops == [arg["subpop"] for arg in single_parameter_args] + assert inference_params.pdists == [arg["pdist"] for arg in single_parameter_args] + + +@pytest.mark.parametrize( + ("modifier_args", "subpopulations"), + [ + ( + [ + AddModifierArg( + pname="Ro_summer", + ptype="seir_modifiers", + parameter_config={ + "method": "SinglePeriodModifier", + "parameter": "Ro", + "subpop": "all", + "value": { + "distribution": "gamma", + "shape": 1.5, + "scale": 0.5, + }, + }, + ), + ], + ["USA"], + ), + ( + [ + AddModifierArg( + pname="gamma_humid", + 
ptype="seir_modifiers", + parameter_config={ + "method": "SinglePeriodModifier", + "parameter": "gamma", + "subpop": "all", + "value": { + "distribution": "gamma", + "shape": 1.5, + "scale": 0.5, + }, + }, + ), + ], + ["NC", "SC", "GA"], + ), + ( + [ + AddModifierArg( + pname="incidH_spring_fall", + ptype="outcome_modifiers", + parameter_config={ + "method": "MultiPeriodModifier", + "parameter": "incidH::probability", + "groups": [ + { + "subpop": "all", + }, + ], + "value": { + "distribution": "beta", + "alpha": 2.0, + "beta": 5.0, + }, + }, + ), + AddModifierArg( + pname="incidH_winter_delay", + ptype="outcome_modifiers", + parameter_config={ + "method": "MultiPeriodModifier", + "parameter": "incidH::delay", + "groups": [ + { + "subpop": "all", + "subpop_groups": [["USA", "Canada"]], + } + ], + "value": { + "distribution": "gamma", + "shape": 1.5, + "scale": 0.5, + }, + }, + ), + ], + ["USA", "Canada"], + ), + ( + [ + AddModifierArg( + pname="Ro_summer_midatlantic", + ptype="seir_modifiers", + parameter_config={ + "method": "SinglePeriodModifier", + "parameter": "Ro", + "subpop": ["VA", "NC"], + "subpop_groups": [["VA", "NC"]], + "value": { + "distribution": "truncnorm", + "a": 0.1, + "b": 2.1, + "mean": 0.9, + "sd": 0.5, + }, + }, + ), + AddModifierArg( + pname="Ro_summer_southeast", + ptype="seir_modifiers", + parameter_config={ + "method": "SinglePeriodModifier", + "parameter": "Ro", + "subpop": ["SC", "GA"], + "subpop_groups": [["SC", "GA"]], + "value": { + "distribution": "truncnorm", + "a": 0.1, + "b": 2.1, + "mean": 0.9, + "sd": 0.5, + }, + }, + ), + ], + ["VA", "NC", "SC", "GA"], + ), + ], +) +def test_adding_modifiers( + modifier_args: list[AddModifierArg], subpopulations: list[str] +) -> None: + """Test adding modifier parameters to an `InferenceParameters` class.""" + inference_params = InferenceParameters(create_confuse_configview_from_dict({}), []) + for arg in modifier_args: + inference_params.add_modifier(**arg.as_dict() | {"subpops": 
subpopulations}) + + assert inference_params.get_dim() == sum( + arg.number_of_subpopulations(subpopulations) for arg in modifier_args + ) + + +@pytest.mark.parametrize( + ("single_parameter_args", "proposal"), + [ + ( + [ + AddSingleParameterArg( + ptype="incidH::delay", + pname="Ro", + subpop="GA", + pdist=GammaDistribution(shape=14.0, scale=0.5), + ), + ], + np.array([7.0], dtype=np.float64), + ), + ( + [ + AddSingleParameterArg( + ptype="incidH::delay", + pname="Ro", + subpop="GA", + pdist=GammaDistribution(shape=14.0, scale=0.5), + ), + ], + np.array([0.0], dtype=np.float64), + ), + ( + [ + AddSingleParameterArg( + ptype="incidH::delay", + pname="Ro", + subpop="GA", + pdist=GammaDistribution(shape=14.0, scale=0.5), + ), + ], + np.array([-0.5], dtype=np.float64), + ), + ( + [ + AddSingleParameterArg( + ptype="seir_modifiers", + pname="gamma", + subpop="GA", + pdist=GammaDistribution(shape=2.0, scale=0.5), + ), + AddSingleParameterArg( + ptype="seir_modifiers", + pname="beta", + subpop="GA", + pdist=UniformDistribution(low=0.0, high=2.0), + ), + ], + np.array([1.0, 1.0], dtype=np.float64), + ), + ( + [ + AddSingleParameterArg( + ptype="seir_modifiers", + pname="gamma", + subpop="GA", + pdist=GammaDistribution(shape=2.0, scale=0.5), + ), + AddSingleParameterArg( + ptype="seir_modifiers", + pname="beta", + subpop="GA", + pdist=UniformDistribution(low=0.0, high=2.0), + ), + ], + np.array([-1.0, 1.0], dtype=np.float64), + ), + ( + [ + AddSingleParameterArg( + ptype="seir_modifiers", + pname="gamma", + subpop="GA", + pdist=GammaDistribution(shape=2.0, scale=0.5), + ), + AddSingleParameterArg( + ptype="seir_modifiers", + pname="beta", + subpop="GA", + pdist=UniformDistribution(low=0.0, high=2.0), + ), + ], + np.array([1.0, -1.0], dtype=np.float64), + ), + ( + [ + AddSingleParameterArg( + ptype="seir_modifiers", + pname="gamma", + subpop="GA", + pdist=GammaDistribution(shape=2.0, scale=0.5), + ), + AddSingleParameterArg( + ptype="seir_modifiers", + pname="beta", + 
subpop="GA", + pdist=UniformDistribution(low=0.0, high=2.0), + ), + ], + np.array([-1.0, -1.0], dtype=np.float64), + ), + ( + [ + AddSingleParameterArg( + ptype="seir_modifiers", + pname="gamma", + subpop="GA", + pdist=GammaDistribution(shape=2.0, scale=0.5), + ), + AddSingleParameterArg( + ptype="seir_modifiers", + pname="beta", + subpop="GA", + pdist=UniformDistribution(low=0.0, high=2.0), + ), + ], + np.array([0.0, 0.0], dtype=np.float64), + ), + ( + [ + AddSingleParameterArg( + ptype="seir_modifiers", + pname="gamma", + subpop="GA", + pdist=GammaDistribution(shape=2.0, scale=0.5), + ), + AddSingleParameterArg( + ptype="seir_modifiers", + pname="beta", + subpop="GA", + pdist=UniformDistribution(low=0.0, high=2.0), + ), + ], + np.array([0.0, 2.0], dtype=np.float64), + ), + ( + [ + AddSingleParameterArg( + ptype="seir_modifiers", + pname="gamma", + subpop="GA", + pdist=GammaDistribution(shape=2.0, scale=0.5), + ), + AddSingleParameterArg( + ptype="seir_modifiers", + pname="beta", + subpop="GA", + pdist=UniformDistribution(low=0.0, high=2.0), + ), + ], + np.array([0.0, 3.0], dtype=np.float64), + ), + ], +) +def test_check_in_bound( + single_parameter_args: list[AddSingleParameterArg], proposal: npt.NDArray[np.float64] +) -> None: + """Test checking if a proposal is within bounds for single parameters.""" + inference_params = InferenceParameters(create_confuse_configview_from_dict({}), []) + for arg in single_parameter_args: + inference_params.add_single_parameter(**arg) + assert proposal.ndim == 1 + assert len(inference_params) == len(proposal) + lower_bounds = [arg["pdist"].support[0] for arg in single_parameter_args] + upper_bounds = [arg["pdist"].support[1] for arg in single_parameter_args] + assert inference_params.check_in_bound(proposal) == np.all( + (proposal >= lower_bounds) & (proposal <= upper_bounds) + ) + + +@pytest.mark.parametrize( + ( + "modifier_args", + "subpopulations", + "proposal", + "snpi_df", + "hnpi_df", + "expected_snpi_df", + 
"expected_hnpi_df", + ), + [ + ( + [ + AddModifierArg( + pname="gamma_humid", + ptype="seir_modifiers", + parameter_config={ + "method": "SinglePeriodModifier", + "parameter": "gamma", + "subpop": "all", + "value": { + "distribution": "gamma", + "shape": 1.5, + "scale": 0.5, + }, + }, + ), + ], + ["GA", "NC", "SC"], + np.arange(1, 4, dtype=np.float64), + pd.DataFrame.from_records( + [ + { + "subpop": "GA", + "modifier_name": "gamma_humid", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "gamma", + "value": 1.5, + }, + { + "subpop": "NC", + "modifier_name": "gamma_humid", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "gamma", + "value": 1.5, + }, + { + "subpop": "SC", + "modifier_name": "gamma_humid", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "gamma", + "value": 1.5, + }, + ] + ), + pd.DataFrame( + columns=[ + "subpop", + "modifier_name", + "start_date", + "end_date", + "parameter", + "value", + ] + ), + pd.DataFrame.from_records( + [ + { + "subpop": "GA", + "modifier_name": "gamma_humid", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "gamma", + "value": 1.0, + }, + { + "subpop": "NC", + "modifier_name": "gamma_humid", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "gamma", + "value": 2.0, + }, + { + "subpop": "SC", + "modifier_name": "gamma_humid", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "gamma", + "value": 3.0, + }, + ] + ), + pd.DataFrame( + columns=[ + "subpop", + "modifier_name", + "start_date", + "end_date", + "parameter", + "value", + ] + ), + ), + ( + [ + AddModifierArg( + pname="incidH_spring_delay", + ptype="outcome_modifiers", + parameter_config={ + "method": "SinglePeriodModifier", + "parameter": "incidH::delay", + "subpop": "all", + "value": { + "distribution": "gamma", + "shape": 14, + "scale": 0.5, + }, + }, + ), + ], + ["USA", "Canada"], + np.array([7.0, 10.0], dtype=np.float64), + 
pd.DataFrame( + columns=[ + "subpop", + "modifier_name", + "start_date", + "end_date", + "parameter", + "value", + ] + ), + pd.DataFrame.from_records( + [ + { + "subpop": "Canada", + "modifier_name": "incidH_spring_delay", + "start_date": "2023-03-01", + "end_date": "2023-05-31", + "parameter": "incidH::delay", + "value": 14.0, + }, + { + "subpop": "USA", + "modifier_name": "incidH_spring_delay", + "start_date": "2023-03-01", + "end_date": "2023-05-31", + "parameter": "incidH::delay", + "value": 14.0, + }, + ] + ), + pd.DataFrame( + columns=[ + "subpop", + "modifier_name", + "start_date", + "end_date", + "parameter", + "value", + ] + ), + pd.DataFrame.from_records( + [ + { + "subpop": "Canada", + "modifier_name": "incidH_spring_delay", + "start_date": "2023-03-01", + "end_date": "2023-05-31", + "parameter": "incidH::delay", + "value": 7.0, + }, + { + "subpop": "USA", + "modifier_name": "incidH_spring_delay", + "start_date": "2023-03-01", + "end_date": "2023-05-31", + "parameter": "incidH::delay", + "value": 10.0, + }, + ] + ), + ), + ( + [ + AddModifierArg( + pname="beta_summer_regional", + ptype="seir_modifiers", + parameter_config={ + "method": "SinglePeriodModifier", + "parameter": "beta", + "subpop": "all", + "subpop_groups": [["VA", "NC"], ["SC", "GA"]], + "value": { + "distribution": "truncnorm", + "a": 0.1, + "b": 2.1, + "mean": 0.9, + "sd": 0.5, + }, + }, + ), + AddModifierArg( + pname="incidH_fall_prob", + ptype="outcome_modifiers", + parameter_config={ + "method": "SinglePeriodModifier", + "parameter": "incidH::probability", + "subpop": "all", + "value": { + "distribution": "beta", + "alpha": 2.0, + "beta": 5.0, + }, + }, + ), + ], + ["VA", "NC", "SC", "GA"], + np.array([1.0, 2.0, 0.8, 0.4, 0.6, 0.2], dtype=np.float64), + pd.DataFrame.from_records( + [ + { + "subpop": "NC,VA", + "modifier_name": "beta_summer_regional", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "beta", + "value": 1.5, + }, + { + "subpop": "GA,SC", + 
"modifier_name": "beta_summer_regional", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "beta", + "value": 1.5, + }, + { + "subpop": "GA,NC,SC", + "modifier_name": "beta_humid", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "beta", + "value": 1.5, + }, + ] + ), + pd.DataFrame.from_records( + [ + { + "subpop": "VA", + "modifier_name": "incidH_fall_prob", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "incidH::probability", + "value": 0.05, + }, + { + "subpop": "NC", + "modifier_name": "incidH_fall_prob", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "incidH::probability", + "value": 0.05, + }, + { + "subpop": "SC", + "modifier_name": "incidH_fall_prob", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "incidH::probability", + "value": 0.05, + }, + { + "subpop": "GA", + "modifier_name": "incidH_fall_prob", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "incidH::probability", + "value": 0.05, + }, + ] + ), + pd.DataFrame.from_records( + [ + { + "subpop": "NC,VA", + "modifier_name": "beta_summer_regional", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "beta", + "value": 1.0, + }, + { + "subpop": "GA,SC", + "modifier_name": "beta_summer_regional", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "beta", + "value": 2.0, + }, + { + "subpop": "GA,NC,SC", + "modifier_name": "beta_humid", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "beta", + "value": 1.5, + }, + ] + ), + pd.DataFrame.from_records( + [ + { + "subpop": "VA", + "modifier_name": "incidH_fall_prob", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "incidH::probability", + "value": 0.2, + }, + { + "subpop": "NC", + "modifier_name": "incidH_fall_prob", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "incidH::probability", + "value": 
0.4, + }, + { + "subpop": "SC", + "modifier_name": "incidH_fall_prob", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "incidH::probability", + "value": 0.6, + }, + { + "subpop": "GA", + "modifier_name": "incidH_fall_prob", + "start_date": "2023-07-01", + "end_date": "2023-09-30", + "parameter": "incidH::probability", + "value": 0.8, + }, + ] + ), + ), + ], +) +def test_inject_proposal( + modifier_args: list[AddModifierArg], + subpopulations: list[str], + proposal: npt.NDArray[np.float64], + snpi_df: pd.DataFrame, + hnpi_df: pd.DataFrame, + expected_snpi_df: pd.DataFrame, + expected_hnpi_df: pd.DataFrame, +) -> None: + """Test injecting a proposal into the modifiers with `InferenceParameters`.""" + # if subpopulations == ["USA", "Canada"]: + # import pdbpp; pdbpp.set_trace() + + inference_params = InferenceParameters(create_confuse_configview_from_dict({}), []) + for arg in modifier_args: + inference_params.add_modifier(**arg.as_dict() | {"subpops": subpopulations}) + modified_snpi_df, modified_hnpi_df = inference_params.inject_proposal( + proposal, snpi_df, hnpi_df + ) + pd.testing.assert_frame_equal( + modified_snpi_df.drop(columns=["value"]), snpi_df.drop(columns=["value"]) + ) + pd.testing.assert_frame_equal( + modified_hnpi_df.drop(columns=["value"]), hnpi_df.drop(columns=["value"]) + ) + pd.testing.assert_frame_equal(modified_snpi_df, expected_snpi_df, check_like=True) + pd.testing.assert_frame_equal(modified_hnpi_df, expected_hnpi_df, check_like=True) diff --git a/flepimop/gempyor_pkg/tests/npi/helpers/test_spatial_groups_class.py b/flepimop/gempyor_pkg/tests/npi/helpers/test_spatial_groups_class.py new file mode 100644 index 000000000..dbca3c3d5 --- /dev/null +++ b/flepimop/gempyor_pkg/tests/npi/helpers/test_spatial_groups_class.py @@ -0,0 +1,77 @@ +"""Unit tests for the `gempyor.NPI.helpers.SpatialGroups` class.""" + +import pytest +from gempyor.NPI.helpers import SpatialGroups + + +@pytest.mark.parametrize( + ("subpopulations", 
"subpopulation_groups", "expected_groups", "expected_ungrouped"), + [ + ( + ["A", "B", "C"], + None, + (), + ("A", "B", "C"), + ), + ( + ["A", "B", "C"], + [], + (), + ("A", "B", "C"), + ), + ( + ["A", "B", "C"], + [["A", "B"], ["C"]], + (("A", "B"), ("C",)), + (), + ), + ( + ["A", "B", "C"], + [["A", "B"]], + (("A", "B"),), + ("C",), + ), + ( + ["A", "B", "C"], + "all", + (("A", "B", "C"),), + (), + ), + ( + ["A", "B", "C", "D", "E", "F"], + [["A", "B", "C"], ["D", "E", "F"]], + (("A", "B", "C"), ("D", "E", "F")), + (), + ), + ( + ["A", "B", "C", "D", "E", "F"], + [["A", "B"], [], ["E", "F"]], + (("A", "B"), ("E", "F")), + ("C", "D"), + ), + ( + ["A", "B", "C", "D", "E", "F"], + [["A", "B"], [], ["E", "F"]], + (("A", "B"), ("E", "F")), + ("C", "D"), + ), + ( + ["VA", "NC", "SC", "GA", "FL"], + [["VA", "FL"], ["SC", "NC"]], + (("FL", "VA"), ("NC", "SC")), + ("GA",), + ), + ], +) +def test_from_subpopulations_for_exact_results_for_select_inputs( + subpopulations: list[str], + subpopulation_groups: list[list[str]] | list[str] | str | None, + expected_groups: tuple[tuple[str]], + expected_ungrouped: tuple[str], +): + """Test the exact results of `from_subpopulations` cls method for select inputs.""" + result = SpatialGroups.from_subpopulations(subpopulations, subpopulation_groups) + assert result.grouped == expected_groups + assert result.ungrouped == expected_ungrouped + assert result.ungrouped == tuple(sorted(result.ungrouped)) + assert all(tuple(sorted(group)) == group for group in result.grouped) diff --git a/flepimop/gempyor_pkg/tests/output/test_chains_class.py b/flepimop/gempyor_pkg/tests/output/test_chains_class.py new file mode 100644 index 000000000..f6dc72ea0 --- /dev/null +++ b/flepimop/gempyor_pkg/tests/output/test_chains_class.py @@ -0,0 +1,120 @@ +"""Unit tests for `gempyor.output.Chains` class.""" + +from datetime import date +from typing import Final + +import numpy as np +import pytest +from gempyor.output import Chains, ModifierInfo, 
ModifierInfoPeriod + +rng = np.random.default_rng(12345) + + +EXAMPLE_CHAINS_ONE: Final[Chains] = Chains( + shape=(4, 30, 2), + log_probability=-rng.lognormal(size=(4, 30)), + samples=rng.normal(size=(4, 30, 2)), + modifiers=[ + ModifierInfo( + kind="seir", + name="seasonal_beta", + subpops=["subpop1"], + periods=[ + ModifierInfoPeriod( + start_date=date(2020, 1, 1), + end_date=date(2020, 6, 30), + ), + ], + parameter="beta", + ), + ModifierInfo( + kind="outcome", + name="hospitalization_rate", + subpops=["subpop1", "subpop2"], + periods=[ + ModifierInfoPeriod( + start_date=date(2020, 1, 1), + end_date=date(2020, 12, 31), + ), + ], + parameter="hosp::probability", + ), + ], +) + + +def determine_new_shape(selector: list[int] | int | None, previous_shape: int) -> int: + """Determine the new shape dimension after subsetting.""" + if selector is None: + return previous_shape + if isinstance(selector, int): + return 1 + return len(selector) + + +@pytest.mark.parametrize( + ("chains", "which_chains", "which_iterations"), + [ + ( + EXAMPLE_CHAINS_ONE, + [0, 2], + [5, 10, 15], + ), + ( + EXAMPLE_CHAINS_ONE, + [0], + [5, 10, 15], + ), + ( + EXAMPLE_CHAINS_ONE, + [0, 2, 3], + [4], + ), + ( + EXAMPLE_CHAINS_ONE, + 0, + 0, + ), + ( + EXAMPLE_CHAINS_ONE, + None, + [0, 1, 2, 3, 4, 5], + ), + ( + EXAMPLE_CHAINS_ONE, + [0, 2], + None, + ), + ], +) +def test_subset_method( + chains: Chains, + which_chains: list[int] | int | None, + which_iterations: list[int] | int | None, +) -> None: + """Test the `subset` method of the `Chains` class.""" + chains_subset = chains.subset(chains=which_chains, iterations=which_iterations) + assert isinstance(chains_subset, Chains) + assert chains.modifiers == chains_subset.modifiers + assert all( + new_shape <= old_shape + for new_shape, old_shape in zip(chains_subset.shape, chains.shape) + ) + assert chains_subset.shape == ( + determine_new_shape(which_chains, chains.shape[0]), + determine_new_shape(which_iterations, chains.shape[1]), + 
chains.shape[2], + ) + if isinstance(which_chains, list) and isinstance(which_iterations, list): + for i, chain_idx in enumerate(which_chains): + for j, iter_idx in enumerate(which_iterations): + assert np.all( + np.isclose( + chains_subset.samples[i, j, :], + chains.samples[chain_idx, iter_idx, :], + ) + ) + assert np.isclose( + chains_subset.log_probability[i, j], + chains.log_probability[chain_idx, iter_idx], + )