Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V1.9 #10

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open

V1.9 #10

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,8 @@ example_scripts/**/*
!example_scripts/*
!example_scripts/**/*.py
configs/tests.yaml
build_pip.sh
.DS_Store
**/*pkl
**/*yaml
.vscode/settings.json
30 changes: 17 additions & 13 deletions grad_june/configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ data_path: '@grad_june/test/data/data.pkl'

save_path: ./example
age_bins_to_save: [0, 18, 65, 100]
store_differentiable_deaths: true
store_cases_by_age: true

timer:
total_days: 15
Expand Down Expand Up @@ -43,31 +45,33 @@ timer:
- household

infection_seed:
log_fraction_initial_cases: -1
type: InfectionSeedByFraction
params:
log_fraction: -1.5

networks:
household:
log_beta: -0.4
log_beta: 0.4
company:
log_beta: -0.3
log_beta: 0.3
school:
log_beta: -0.3
log_beta: 0.3
pub:
log_beta: -1.2
log_beta: 0.2
gym:
log_beta: -1.2
log_beta: 0.2
grocery:
log_beta: -1.2
log_beta: 0.2
visit:
log_beta: -1.2
log_beta: 0.2
cinema:
log_beta: -1.2
log_beta: 0.2
university:
log_beta: -0.5
log_beta: 0.5
care_visit:
log_beta: -0.4
log_beta: 0.4
care_home:
log_beta: -0.4
log_beta: 0.4

policies:
interaction:
Expand Down Expand Up @@ -112,7 +116,7 @@ transmission:
scale: 0.03
shift:
dist: Normal
loc: -2.12
loc: 2.12
scale: 0.1

symptoms:
Expand Down
63 changes: 63 additions & 0 deletions grad_june/demographics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
from torch_geometric.data import HeteroData

def store_differentiable_deaths(data: HeteroData, dead_idx: int):
"""
Returns differentiable deaths. The results are stored
in data["results"]
"""
symptoms = data["agent"].symptoms
#dead_idx = self.model.symptoms_updater.stages_ids[-1]
deaths = (
(symptoms["current_stage"] == dead_idx)
* symptoms["current_stage"]
/ dead_idx
)
if data["results"]["deaths_per_timestep"] is not None:
data["results"]["deaths_per_timestep"] = torch.hstack(
(data["results"]["deaths_per_timestep"], deaths.sum())
)
else:
data["results"]["deaths_per_timestep"] = deaths.sum()

def get_cases_by_age(data: HeteroData, age_bins: torch.Tensor):
device = age_bins.device
ret = torch.zeros(age_bins.shape[0] - 1, device=device)
for i in range(1, age_bins.shape[0]):
mask1 = data["agent"].age < age_bins[i]
mask2 = data["agent"].age > age_bins[i - 1]
mask = mask1 * mask2
ret[i - 1] = (data["agent"].is_infected * mask).sum()
return ret

def get_people_by_age(ages: torch.Tensor, age_bins: torch.Tensor):
ret = {}
for i in range(1, age_bins.shape[0]):
mask1 = ages < age_bins[i]
mask2 = ages > age_bins[i - 1]
mask = mask1 * mask2
ret[int(age_bins[i].item())] = mask.sum()
return ret

def get_cases_by_ethnicity(data: HeteroData, ethnicities):
device = ethnicities.device
ret = torch.zeros(len(ethnicities), device=device)
for i, ethnicity in enumerate(ethnicities):
mask = torch.tensor(

Check warning on line 46 in grad_june/demographics.py

View check run for this annotation

Codecov / codecov/patch

grad_june/demographics.py#L43-L46

Added lines #L43 - L46 were not covered by tests
data["agent"].ethnicity == ethnicity, device=device
)
ret[i] = (mask * data["agent"].is_infected).sum()
return ret

Check warning on line 50 in grad_june/demographics.py

View check run for this annotation

Codecov / codecov/patch

grad_june/demographics.py#L49-L50

Added lines #L49 - L50 were not covered by tests

def get_people_per_area(agent_ids: torch.Tensor, area_ids: torch.Tensor):
"""Gets people ids in each area.

**Arguments:**

- `agent_ids`: Ids of all agents.
- `area_ids`: Area ids of all agents.
"""
people_per_area = {}
for area_id in torch.unique(area_ids):
people_per_area[area_id.item()] = agent_ids[area_ids == area_id]
return people_per_area
53 changes: 13 additions & 40 deletions grad_june/infection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torch_geometric.data import HeteroData

class IsInfectedSampler(torch.nn.Module):
def forward(self, not_infected_probs):
Expand All @@ -14,50 +15,22 @@ def forward(self, not_infected_probs):
infection = torch.nn.functional.gumbel_softmax(
logits, dim=0, tau=0.1, hard=True
)
is_infected = 1.0 - infection[0, :]
return is_infected
return infection[1, :]

def infect_people(data: HeteroData, time: int, new_infected: torch.Tensor):
"""
Sets the `new_infected` individuals to infected at time `time`.

def infect_people(data, timer, new_infected):
**Arguments:**

- `data`: the graph data
- `time`: the time step at which the infection happens
- `new_infected`: a tensor of size [N] where N is the number of agents.
"""
data["agent"].susceptibility = torch.clamp(
data["agent"].susceptibility - new_infected, min=0.0
)
data["agent"].is_infected = data["agent"].is_infected + new_infected
data["agent"].infection_time = data["agent"].infection_time + new_infected * (
timer.now - data["agent"].infection_time
)


def infect_fraction_of_people(
data, timer, symptoms_updater, fraction, device
):
n_infections = data["agent"].susceptibility.shape[0]
n_agents = data["agent"].id.shape[0]
probs = fraction * torch.ones(n_agents, device=device)
sampler = IsInfectedSampler()
new_infected = sampler(
1.0 - probs
) # sampler takes not inf probs
infect_people(data, timer, new_infected)
return new_infected


def infect_people_at_indices(data, indices, device="cpu"):
susc = data["agent"]["susceptibility"].cpu().numpy()
is_inf = data["agent"]["is_infected"].cpu().numpy()
inf_t = data["agent"]["infection_time"].cpu().numpy()
next_stage = data["agent"]["symptoms"]["next_stage"].cpu().numpy()
current_stage = data["agent"]["symptoms"]["current_stage"].cpu().numpy()
susc[indices] = 0.0
is_inf[indices] = 1.0
inf_t[indices] = 0.0
next_stage[indices] = 2
current_stage[indices] = 1
data["agent"]["susceptibility"] = torch.tensor(susc, device=device)
data["agent"]["is_infected"] = torch.tensor(is_inf, device=device)
data["agent"]["infection_time"] = torch.tensor(inf_t, device=device)
data["agent"]["symptoms"]["next_stage"] = torch.tensor(next_stage, device=device)
data["agent"]["symptoms"]["current_stage"] = torch.tensor(
current_stage, device=device
)
return data
time - data["agent"].infection_time
)
20 changes: 10 additions & 10 deletions grad_june/infection_networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class InfectionNetwork(MessagePassing):
def __init__(self, log_beta, device="cpu"):
super().__init__( aggr="add", node_dim=-1)
super().__init__(aggr="add", node_dim=-1)
self.device = device
if type(log_beta) != torch.nn.Parameter:
self.log_beta = torch.tensor(float(log_beta))
Expand Down Expand Up @@ -38,11 +38,10 @@ def _get_beta(self, policies, timer, data):
beta = 10.0**self.log_beta
if interaction_policies:
beta = interaction_policies.apply(beta=beta, name=self.name, timer=timer)
beta = beta * torch.ones(len(data[self.name]["id"]), device=self.device)
return beta
return beta

def _get_people_per_group(self, data):
return data[self.name]["people"]
def _get_people_per_group(self, data, timer):
return data[self.name].people

def _get_transmissions(self, data, policies, timer):
if policies.quarantine_policies:
Expand All @@ -60,7 +59,7 @@ def _get_susceptibilities(self, data, policies, timer):

def forward(self, data, timer, policies):
beta = self._get_beta(policies=policies, timer=timer, data=data)
people_per_group = self._get_people_per_group(data)
people_per_group = self._get_people_per_group(data, timer)
p_contact = torch.maximum(
torch.minimum(
1.0 / (people_per_group - 1), torch.tensor(1.0, device=self.device)
Expand Down Expand Up @@ -134,12 +133,16 @@ def forward(
network = self.networks[activity]
trans_susc += network(data=data, timer=timer, policies=policies)
trans_susc = torch.clamp(
trans_susc, min=1e-6, max = 100
trans_susc, min=1e-6, max=100
) # this is necessary to avoid gradient nans
not_infected_probs = torch.exp(-trans_susc * delta_time)
not_infected_probs = torch.clamp(not_infected_probs, min=0.0, max=1.0)
return not_infected_probs

def __iter__(self):
for network in self.networks.values():
yield network


class HouseholdNetwork(InfectionNetwork):
def _get_transmissions(self, data, policies, timer):
Expand All @@ -148,8 +151,6 @@ def _get_transmissions(self, data, policies, timer):
def _get_susceptibilities(self, data, policies, timer):
return data["agent"].susceptibility

pass


class CareHomeNetwork(InfectionNetwork):
pass
Expand All @@ -165,4 +166,3 @@ class CompanyNetwork(InfectionNetwork):

class UniversityNetwork(InfectionNetwork):
pass

16 changes: 13 additions & 3 deletions grad_june/infection_networks/leisure_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,21 @@ def _get_beta(self, policies, timer, data):
beta = 10.0**self.log_beta
if interaction_policies:
beta = interaction_policies.apply(beta=beta, name=self.name, timer=timer)
beta = beta * torch.ones(len(data["leisure"]["id"]), device=self.device)
beta = beta
return beta

def _get_people_per_group(self, data):
return data["leisure"]["people"]
def _get_people_per_group(self, data, timer):
if self.weekday_probabilities is None:
self.initialize_leisure_probabilities(data)
if timer.day_type == "weekday":
leisure_mask = self.weekday_probabilities
else:
leisure_mask = self.weekend_probabilities
aux = torch.ones(len(data["leisure"]["id"]), device=self.device)
prob_leisure = leisure_mask
edge_index = self._get_edge_index(data)
people_per_group = self.propagate(edge_index, x=prob_leisure, y=aux)
return people_per_group

def _get_transmissions(self, data, policies, timer):
if self.weekday_probabilities is None:
Expand Down
Loading
Loading