diff --git a/src/datajoint/expression.py b/src/datajoint/expression.py index 971a1ee5e..354e3ef35 100644 --- a/src/datajoint/expression.py +++ b/src/datajoint/expression.py @@ -728,7 +728,7 @@ def to_polars(self, order_by=None, limit=None, offset=None, squeeze=False): try: import polars except ImportError: - raise ImportError("polars is required for to_polars(). " "Install with: pip install datajoint[polars]") + raise ImportError("polars is required for to_polars(). Install with: pip install datajoint[polars]") dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze) return polars.DataFrame(dicts) @@ -747,7 +747,7 @@ def to_arrow(self, order_by=None, limit=None, offset=None, squeeze=False): try: import pyarrow except ImportError: - raise ImportError("pyarrow is required for to_arrow(). " "Install with: pip install datajoint[arrow]") + raise ImportError("pyarrow is required for to_arrow(). Install with: pip install datajoint[arrow]") dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze) if not dicts: return pyarrow.table({}) @@ -1039,7 +1039,7 @@ def __len__(self): ).fetchone()[0] def __bool__(self): - return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql()))) + return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())).fetchone()[0]) class Union(QueryExpression): @@ -1101,7 +1101,7 @@ def __len__(self): ).fetchone()[0] def __bool__(self): - return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql()))) + return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())).fetchone()[0]) class U: diff --git a/tests/integration/test_aggr_regressions.py b/tests/integration/test_aggr_regressions.py index 22d10c676..cf4f920b0 100644 --- a/tests/integration/test_aggr_regressions.py +++ b/tests/integration/test_aggr_regressions.py @@ -180,3 +180,70 @@ def test_extend_invalid_raises_error(schema_uuid): with pytest.raises(DataJointError) as exc_info: Topic.extend(Item) assert "left operand to determine" in str(exc_info.value).lower() + + +class TestBoolMethod: + """ + Tests for __bool__ method on Aggregation and Union (issue #1234). + + bool(query) should return True if query has rows, False if empty. + """ + + def test_aggregation_bool_with_results(self, schema_aggr_reg_with_abx): + """Aggregation with results should be truthy.""" + A.insert([(1,), (2,), (3,)]) + B.insert([(1, 10), (1, 20), (2, 30)]) + aggr = A.aggr(B, count="count(id2)") + assert bool(aggr) is True + assert len(aggr) > 0 + + def test_aggregation_bool_empty(self, schema_aggr_reg_with_abx): + """Aggregation with no results should be falsy.""" + A.insert([(1,), (2,), (3,)]) + B.insert([(1, 10), (1, 20), (2, 30)]) + # Restrict to non-existent entry + aggr = (A & "id=999").aggr(B, count="count(id2)") + assert bool(aggr) is False + assert len(aggr) == 0 + + def test_aggregation_bool_matches_len(self, schema_aggr_reg_with_abx): + """bool(aggr) should equal len(aggr) > 0.""" + A.insert([(10,), (20,)]) + B.insert([(10, 100)]) + # With results + aggr_has = A.aggr(B, count="count(id2)") + assert bool(aggr_has) == (len(aggr_has) > 0) + # Without results + aggr_empty = (A & "id=999").aggr(B, count="count(id2)") + assert bool(aggr_empty) == (len(aggr_empty) > 0) + + def test_union_bool_with_results(self, schema_aggr_reg_with_abx): + """Union with results should be truthy.""" + A.insert([(100,), (200,)]) + B.insert([(100, 1), (200, 2)]) + q1 = B & "id=100" + q2 = B & "id=200" + union = q1 + q2 + assert bool(union) is True + assert len(union) > 0 + + def test_union_bool_empty(self, schema_aggr_reg_with_abx): + """Union with no results should be falsy.""" + A.insert([(100,), (200,)]) + B.insert([(100, 1), (200, 2)]) + q1 = B & "id=999" + q2 = B & "id=998" + union = q1 + q2 + assert bool(union) is False + assert len(union) == 0 + + def test_union_bool_matches_len(self, schema_aggr_reg_with_abx): + """bool(union) should equal len(union) > 0.""" + A.insert([(100,), (200,)]) + B.insert([(100, 1)]) + # With results + union_has = (B & "id=100") + (B & "id=100") + assert bool(union_has) == (len(union_has) > 0) + # Without results + union_empty = (B & "id=999") + (B & "id=998") + assert bool(union_empty) == (len(union_empty) > 0)