Skip to content

Commit 24df467

Browse files
committed
adds an mtcars test
1 parent 6429482 commit 24df467

File tree

1 file changed

+190
-0
lines changed

1 file changed

+190
-0
lines changed

tests/test_mtcars_comparison.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""Test pyfixest against R fixest using mtcars dataset."""
2+
3+
import numpy as np
4+
import pandas as pd
5+
import pytest
6+
import rpy2.robjects as ro
7+
from rpy2.robjects import pandas2ri
8+
from rpy2.robjects.packages import importr
9+
10+
import pyfixest as pf
11+
12+
pandas2ri.activate()
13+
14+
fixest = importr("fixest")
15+
stats = importr("stats")
16+
broom = importr("broom")
17+
18+
# Tolerance for comparing results
19+
rtol = 1e-06
20+
atol = 1e-06
21+
22+
23+
def check_absolute_diff(x1, x2, tol, msg=None):
24+
"""Check for absolute differences."""
25+
if isinstance(x1, (int, float)):
26+
x1 = np.array([x1])
27+
if isinstance(x2, (int, float)):
28+
x2 = np.array([x2])
29+
msg = "" if msg is None else msg
30+
31+
# handle nan values
32+
nan_mask_x1 = np.isnan(x1)
33+
nan_mask_x2 = np.isnan(x2)
34+
35+
if not np.array_equal(nan_mask_x1, nan_mask_x2):
36+
raise AssertionError(f"{msg}: NaN positions do not match")
37+
38+
valid_mask = ~nan_mask_x1 # Mask for non-NaN elements (same for x1 and x2)
39+
assert np.all(np.abs(x1[valid_mask] - x2[valid_mask]) < tol), msg
40+
41+
42+
def _get_r_tidy(r_fit):
43+
"""Get tidied results from R fixest fit."""
44+
tidied_r = broom.tidy_fixest(r_fit, conf_int=ro.BoolVector([False]))
45+
df_r = pd.DataFrame(tidied_r).T
46+
df_r.columns = ["term", "estimate", "std.error", "statistic", "p.value"]
47+
return df_r.set_index("term")
48+
49+
50+
@pytest.fixture(scope="module")
51+
def mtcars_data():
52+
"""Load mtcars dataset from R."""
53+
mtcars = ro.r["mtcars"]
54+
return pandas2ri.rpy2py(mtcars)
55+
56+
57+
@pytest.mark.against_r_core
58+
@pytest.mark.parametrize("formula", [
59+
"mpg ~ hp + wt + C(cyl)",
60+
"mpg ~ hp + wt + C(cyl) + C(gear)",
61+
"mpg ~ hp * wt + C(cyl)",
62+
])
63+
@pytest.mark.parametrize("vcov", ["iid", "hetero"])
64+
def test_feols_mtcars(mtcars_data, formula, vcov):
65+
"""Test feols against R fixest using mtcars."""
66+
# Convert formula for R
67+
r_formula = formula.replace("C(", "factor(")
68+
69+
# Fit in R
70+
r_fit = fixest.feols(
71+
ro.Formula(r_formula),
72+
data=mtcars_data,
73+
vcov=vcov,
74+
ssc=fixest.ssc(True, "none", False, True, "min", "min"),
75+
)
76+
77+
# Fit in Python
78+
py_fit = pf.feols(
79+
fml=formula,
80+
data=mtcars_data,
81+
vcov=vcov,
82+
ssc=pf.ssc(k_adj=True, k_fixef="none", G_adj=True),
83+
)
84+
85+
# Get results
86+
r_tidy = _get_r_tidy(r_fit)
87+
py_tidy = py_fit.tidy()
88+
89+
# Compare coefficient for 'wt'
90+
r_coef = r_tidy.loc["wt", "estimate"]
91+
py_coef = py_tidy.loc["wt", "Estimate"]
92+
check_absolute_diff(py_coef, r_coef, atol, f"Coefficients don't match for {formula}")
93+
94+
# Compare standard error for 'wt'
95+
r_se = r_tidy.loc["wt", "std.error"]
96+
py_se = py_tidy.loc["wt", "Std. Error"]
97+
check_absolute_diff(py_se, r_se, atol, f"Standard errors don't match for {formula}, vcov={vcov}")
98+
99+
100+
@pytest.mark.against_r_core
101+
@pytest.mark.parametrize("formula", [
102+
"mpg ~ hp + wt + C(cyl)",
103+
"mpg ~ hp + wt + C(cyl) + C(gear)",
104+
"mpg ~ hp * wt + C(cyl)",
105+
])
106+
@pytest.mark.parametrize("vcov", ["iid", "hetero"])
107+
def test_feglm_gaussian_mtcars(mtcars_data, formula, vcov):
108+
"""Test feglm with Gaussian family against R fixest using mtcars."""
109+
# Convert formula for R
110+
r_formula = formula.replace("C(", "factor(")
111+
112+
# Fit in R
113+
r_fit = fixest.feglm(
114+
ro.Formula(r_formula),
115+
data=mtcars_data,
116+
family=stats.gaussian(),
117+
vcov=vcov,
118+
)
119+
120+
# Fit in Python
121+
py_fit = pf.feglm(
122+
fml=formula,
123+
data=mtcars_data,
124+
family="gaussian",
125+
vcov=vcov,
126+
)
127+
128+
# Get results
129+
r_tidy = _get_r_tidy(r_fit)
130+
py_tidy = py_fit.tidy()
131+
132+
# Compare coefficient for 'wt'
133+
r_coef = r_tidy.loc["wt", "estimate"]
134+
py_coef = py_tidy.loc["wt", "Estimate"]
135+
check_absolute_diff(py_coef, r_coef, atol, f"Coefficients don't match for {formula}")
136+
137+
# Compare standard error for 'wt'
138+
r_se = r_tidy.loc["wt", "std.error"]
139+
py_se = py_tidy.loc["wt", "Std. Error"]
140+
check_absolute_diff(
141+
py_se, r_se, atol,
142+
f"Standard errors don't match for feglm {formula}, vcov={vcov}"
143+
)
144+
145+
146+
@pytest.mark.against_r_core
147+
@pytest.mark.parametrize("formula", [
148+
"mpg ~ hp + wt | cyl",
149+
"mpg ~ hp + wt | cyl + gear",
150+
])
151+
@pytest.mark.parametrize("vcov", ["iid", "hetero"])
152+
def test_feglm_gaussian_with_fe_mtcars(mtcars_data, formula, vcov):
153+
"""Test feglm with Gaussian family and fixed effects against R fixest using mtcars."""
154+
# Fit in R
155+
r_fit = fixest.feglm(
156+
ro.Formula(formula),
157+
data=mtcars_data,
158+
family=stats.gaussian(),
159+
vcov=vcov,
160+
)
161+
162+
# Fit in Python
163+
py_fit = pf.feglm(
164+
fml=formula,
165+
data=mtcars_data,
166+
family="gaussian",
167+
vcov=vcov,
168+
)
169+
170+
# Get results
171+
r_tidy = _get_r_tidy(r_fit)
172+
py_tidy = py_fit.tidy()
173+
174+
# Compare coefficient for 'wt'
175+
r_coef = r_tidy.loc["wt", "estimate"]
176+
py_coef = py_tidy.loc["wt", "Estimate"]
177+
check_absolute_diff(
178+
py_coef, r_coef, atol,
179+
f"Coefficients don't match for {formula} with FE"
180+
)
181+
182+
# Compare standard error for 'wt'
183+
r_se = r_tidy.loc["wt", "std.error"]
184+
py_se = py_tidy.loc["wt", "Std. Error"]
185+
check_absolute_diff(
186+
py_se, r_se, atol,
187+
f"Standard errors don't match for feglm {formula} with FE, vcov={vcov}"
188+
)
189+
190+

0 commit comments

Comments
 (0)