diff --git a/src/grelu/model/models.py b/src/grelu/model/models.py index 59292ed..0d3dce8 100644 --- a/src/grelu/model/models.py +++ b/src/grelu/model/models.py @@ -772,15 +772,21 @@ def __init__( self, n_tasks: int, n_transformers: int = 11, + organism="human", # head crop_len=0, final_pool_func="avg", dtype=None, device=None, ): + if organism == "human": + n_tasks_default = 5313 + elif organism == "mouse": + n_tasks_default = 1643 + model = EnformerModel( crop_len=crop_len, - n_tasks=5313, + n_tasks=n_tasks_default, channels=1536, n_transformers=11, n_heads=8, @@ -797,10 +803,10 @@ def __init__( # Load state dict from grelu.resources import get_artifact - art = get_artifact("human_state_dict", project="enformer", alias="latest") + art = get_artifact(f"{organism}_state_dict", project="enformer", alias="latest") with TemporaryDirectory() as d: art.download(d) - state_dict = torch.load(Path(d) / "human.h5") + state_dict = torch.load(Path(d) / f"{organism}.h5") model.load_state_dict(state_dict) @@ -809,7 +815,6 @@ def __init__( model.embedding.transformer_tower.blocks[:n_transformers] ) - # Change head head = ConvHead( n_tasks=n_tasks, in_channels=3072,