
ValueError for predictions after transform with mismatched N tasks #175

@colinshew


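Solved, but this could benefit from a clearer error message! After adding a transform that reduces the number of tasks, `.predict_on_dataset()` raises a `ValueError` about mismatched shapes when `return_df=True`. The DataFrame is built with the model's original task names (`self.data_params["tasks"]["name"]`, 8 here) as columns, but the transformed predictions have only 1 column. Returning an array works fine, as does `.predict_on_seqs()`.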
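
As a stopgap, the array output can be labeled by hand. This is a minimal sketch using only NumPy and pandas. The zero-filled `probs` is a stand-in (an assumption, not real output) for what `predict_on_dataset(..., return_df=False)` returns after a transform that collapses the tasks into a single score. The traceback implies that array has shape `(N, 1, 1)`. The column name `"specificity"` is a hypothetical choice.

```python
import numpy as np
import pandas as pd

# Stand-in for the array returned with return_df=False after the transform:
# (n_seqs, n_output_tasks=1, output_length=1), matching the traceback's (5078, 1).
probs = np.zeros((5078, 1, 1))

# Drop the trailing length axis and name the single remaining column yourself,
# instead of reusing the model's original 8 task names.
df = pd.DataFrame(probs.squeeze(-1), columns=["specificity"])
```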

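```python
import grelu.data.dataset
from grelu.transforms.prediction_transforms import Specificity

# Assumes `model` (a trained LightningModel), `name` (the target cell type),
# `cell_types` (the cell-type task names) and `seqs` are defined earlier.
transform = Specificity(
    on_tasks=[name],
    on_aggfunc="min",
    off_tasks=[x for x in cell_types if x != name],
    off_aggfunc="max",
    model=model,
)
model.add_transform(transform)  # predictions now have a single output task

seqs_ds = grelu.data.dataset.SeqDataset(seqs)
probs = model.predict_on_dataset(seqs_ds, devices=0, num_workers=7, return_df=True)
```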
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[39], line 1
----> 1 probs = model.predict_on_dataset(seqs_ds, devices=0, num_workers=7, return_df=True)

File ~/scratch/conda/envs/grelu_stable/lib/python3.11/site-packages/grelu/lightning/__init__.py:782, in LightningModel.predict_on_dataset(self, dataset, devices, num_workers, batch_size, augment_aggfunc, return_df, precision)
    780 if return_df:
    781     if (preds.ndim == 3) and (preds.shape[-1] == 1):
--> 782         preds = pd.DataFrame(
    783             preds.squeeze(-1), columns=self.data_params["tasks"]["name"]
    784         )
    785     else:
    786         warnings.warn(
    787             "Cannot produce dataframe output."
    788             + "Either output length > 1 or augmented sequences are not aggregated."
    789         )

File ~/scratch/conda/envs/grelu_stable/lib/python3.11/site-packages/pandas/core/frame.py:831, in DataFrame.__init__(self, data, index, columns, dtype, copy)
    820         mgr = dict_to_mgr(
    821             # error: Item "ndarray" of "Union[ndarray, Series, Index]" has no
    822             # attribute "name"
   (...)    828             copy=_copy,
    829         )
    830     else:
--> 831         mgr = ndarray_to_mgr(
    832             data,
    833             index,
    834             columns,
    835             dtype=dtype,
    836             copy=copy,
    837             typ=manager,
    838         )
    840 # For data is list-like, or Iterable (will consume into list)
    841 elif is_list_like(data):

File ~/scratch/conda/envs/grelu_stable/lib/python3.11/site-packages/pandas/core/internals/construction.py:336, in ndarray_to_mgr(values, index, columns, dtype, copy, typ)
    331 # _prep_ndarraylike ensures that values.ndim == 2 at this point
    332 index, columns = _get_axes(
    333     values.shape[0], values.shape[1], index=index, columns=columns
    334 )
--> 336 _check_values_indices_shape_match(values, index, columns)
    338 if typ == "array":
    339     if issubclass(values.dtype.type, str):

File ~/scratch/conda/envs/grelu_stable/lib/python3.11/site-packages/pandas/core/internals/construction.py:420, in _check_values_indices_shape_match(values, index, columns)
    418 passed = values.shape
    419 implied = (len(index), len(columns))
--> 420 raise ValueError(f"Shape of passed values is {passed}, indices imply {implied}")

ValueError: Shape of passed values is (5078, 1), indices imply (5078, 8)
```
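
One way to get a clearer message is to compare the number of output tasks with the stored task names before calling `pd.DataFrame`. The sketch below is a hypothetical standalone version of the `return_df` branch shown in the traceback. `preds_to_df` and its `task_names` argument are illustrative names, not gReLU API. Inside `LightningModel` the names would come from `self.data_params["tasks"]["name"]`.

```python
import warnings

import numpy as np
import pandas as pd


def preds_to_df(preds: np.ndarray, task_names):
    """Hypothetical stand-in for the return_df branch of predict_on_dataset.

    Mirrors the logic in the traceback, but checks the number of output
    tasks against the stored task names before building the DataFrame.
    """
    task_names = list(task_names)
    if preds.ndim == 3 and preds.shape[-1] == 1:
        values = preds.squeeze(-1)  # (n_seqs, n_tasks)
        if values.shape[1] != len(task_names):
            # Explain the likely cause instead of surfacing pandas' shape error.
            raise ValueError(
                f"Model returned {values.shape[1]} output task(s), but "
                f"{len(task_names)} task name(s) are stored in data_params. "
                "A prediction transform may have changed the number of tasks; "
                "call with return_df=False to get the raw array instead."
            )
        return pd.DataFrame(values, columns=task_names)
    # Same fallback as the original branch (with a space between sentences).
    warnings.warn(
        "Cannot produce dataframe output. "
        "Either output length > 1 or augmented sequences are not aggregated."
    )
    return preds
```

Raising, rather than silently falling back to generic column names, keeps the output from mislabeling tasks when a transform changes their number.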
