Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/datajoint/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def to_polars(self, order_by=None, limit=None, offset=None, squeeze=False):
try:
import polars
except ImportError:
raise ImportError("polars is required for to_polars(). " "Install with: pip install datajoint[polars]")
raise ImportError("polars is required for to_polars(). Install with: pip install datajoint[polars]")
dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
return polars.DataFrame(dicts)

Expand All @@ -747,7 +747,7 @@ def to_arrow(self, order_by=None, limit=None, offset=None, squeeze=False):
try:
import pyarrow
except ImportError:
raise ImportError("pyarrow is required for to_arrow(). " "Install with: pip install datajoint[arrow]")
raise ImportError("pyarrow is required for to_arrow(). Install with: pip install datajoint[arrow]")
dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
if not dicts:
return pyarrow.table({})
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def __len__(self):
).fetchone()[0]

def __bool__(self):
return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())))
return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())).fetchone()[0])


class Union(QueryExpression):
Expand Down Expand Up @@ -1101,7 +1101,7 @@ def __len__(self):
).fetchone()[0]

def __bool__(self):
return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())))
return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())).fetchone()[0])


class U:
Expand Down
67 changes: 67 additions & 0 deletions tests/integration/test_aggr_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,70 @@ def test_extend_invalid_raises_error(schema_uuid):
with pytest.raises(DataJointError) as exc_info:
Topic.extend(Item)
assert "left operand to determine" in str(exc_info.value).lower()


class TestBoolMethod:
"""
Tests for __bool__ method on Aggregation and Union (issue #1234).

bool(query) should return True if query has rows, False if empty.
"""

def test_aggregation_bool_with_results(self, schema_aggr_reg_with_abx):
"""Aggregation with results should be truthy."""
A.insert([(1,), (2,), (3,)])
B.insert([(1, 10), (1, 20), (2, 30)])
aggr = A.aggr(B, count="count(id2)")
assert bool(aggr) is True
assert len(aggr) > 0

def test_aggregation_bool_empty(self, schema_aggr_reg_with_abx):
"""Aggregation with no results should be falsy."""
A.insert([(1,), (2,), (3,)])
B.insert([(1, 10), (1, 20), (2, 30)])
# Restrict to non-existent entry
aggr = (A & "id=999").aggr(B, count="count(id2)")
assert bool(aggr) is False
assert len(aggr) == 0

def test_aggregation_bool_matches_len(self, schema_aggr_reg_with_abx):
"""bool(aggr) should equal len(aggr) > 0."""
A.insert([(10,), (20,)])
B.insert([(10, 100)])
# With results
aggr_has = A.aggr(B, count="count(id2)")
assert bool(aggr_has) == (len(aggr_has) > 0)
# Without results
aggr_empty = (A & "id=999").aggr(B, count="count(id2)")
assert bool(aggr_empty) == (len(aggr_empty) > 0)

def test_union_bool_with_results(self, schema_aggr_reg_with_abx):
"""Union with results should be truthy."""
A.insert([(100,), (200,)])
B.insert([(100, 1), (200, 2)])
q1 = B & "id=100"
q2 = B & "id=200"
union = q1 + q2
assert bool(union) is True
assert len(union) > 0

def test_union_bool_empty(self, schema_aggr_reg_with_abx):
"""Union with no results should be falsy."""
A.insert([(100,), (200,)])
B.insert([(100, 1), (200, 2)])
q1 = B & "id=999"
q2 = B & "id=998"
union = q1 + q2
assert bool(union) is False
assert len(union) == 0

def test_union_bool_matches_len(self, schema_aggr_reg_with_abx):
"""bool(union) should equal len(union) > 0."""
A.insert([(100,), (200,)])
B.insert([(100, 1)])
# With results
union_has = (B & "id=100") + (B & "id=100")
assert bool(union_has) == (len(union_has) > 0)
# Without results
union_empty = (B & "id=999") + (B & "id=998")
assert bool(union_empty) == (len(union_empty) > 0)
Loading