Skip to content

Commit e3badf0

Browse files
authored
NumPy fusion type coercion fix (#703)
* Fix type coercion in fusion pass * Add fusion test case
1 parent fd8073f commit e3badf0

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

codon/cir/transform/numpy/expr.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,25 @@ types::Type *coerceScalarArray(NumPyType &scalar, NumPyType &array,
6363
return scalar.getIRBaseType(T);
6464
}
6565

66+
bool isPythonScalar(NumPyType &t) {
67+
if (t.isArray())
68+
return false;
69+
auto dt = t.dtype;
70+
return (dt == NumPyType::NP_TYPE_BOOL || dt == NumPyType::NP_TYPE_I64 ||
71+
dt == NumPyType::NP_TYPE_F64 || dt == NumPyType::NP_TYPE_C128);
72+
}
73+
6674
template <typename E>
6775
types::Type *decideTypes(E *expr, NumPyType &lhs, NumPyType &rhs,
6876
NumPyPrimitiveTypes &T) {
6977
// Special case(s)
7078
if (expr->op == E::NP_OP_COPYSIGN)
7179
return expr->type.getIRBaseType(T);
7280

73-
if (lhs.isArray() && !rhs.isArray())
81+
if (lhs.isArray() && isPythonScalar(rhs))
7482
return coerceScalarArray(rhs, lhs, T);
7583

76-
if (!lhs.isArray() && rhs.isArray())
84+
if (isPythonScalar(lhs) && rhs.isArray())
7785
return coerceScalarArray(lhs, rhs, T);
7886

7987
auto *t1 = lhs.getIRBaseType(T);

test/numpy/test_routines.codon

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2479,3 +2479,12 @@ def test_dragon4_interface():
24792479
assert equal(fpos(f('1.001', tp), precision=1, trim='-'), "1")
24802480

24812481
test_dragon4_interface()
2482+
2483+
@test
2484+
def test_fusion_coercion():
2485+
# more extensive tests in test_fusion.codon
2486+
a = np.array([1.0, 2.0])
2487+
b = np.multiply(a, np.float32(2)) + np.float32(0.5)
2488+
assert np.array_equal(b, [2.5, 4.5])
2489+
2490+
test_fusion_coercion()

0 commit comments

Comments
 (0)