|
| 1 | +"""Test pyfixest against R fixest using mtcars dataset.""" |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | +import pytest |
| 6 | +import rpy2.robjects as ro |
| 7 | +from rpy2.robjects import pandas2ri |
| 8 | +from rpy2.robjects.packages import importr |
| 9 | + |
| 10 | +import pyfixest as pf |
| 11 | + |
# Enable rpy2's pandas conversion for the whole module.
# NOTE(review): this is a module-level side effect at import time — confirm
# it is still the recommended approach for the installed rpy2 version.
pandas2ri.activate()

# Handles to the R packages used by the tests below (loaded once at import).
fixest = importr("fixest")
stats = importr("stats")
broom = importr("broom")

# Tolerance for comparing results
# NOTE(review): only `atol` is referenced by the code visible in this file;
# `rtol` looks unused here — confirm before removing.
rtol = 1e-06
atol = 1e-06
| 21 | + |
| 22 | + |
def check_absolute_diff(x1, x2, tol, msg=None):
    """Assert that two numeric inputs agree element-wise within ``tol``.

    Parameters
    ----------
    x1, x2 : scalar or array-like
        Values to compare. Python scalars, NumPy scalars, lists, arrays and
        pandas Series are accepted; both are coerced to float arrays with at
        least one dimension and compared positionally.
    tol : float
        Strict upper bound on the allowed absolute difference.
    msg : str, optional
        Text used in the ``AssertionError`` when the comparison fails.

    Raises
    ------
    AssertionError
        If the shapes differ, if NaNs are not in the same positions, or if
        any pair of non-NaN values differs by ``tol`` or more.
    """
    msg = "" if msg is None else msg

    # Coerce up front so that lists, NumPy scalars and object-dtype values
    # all support np.isnan and boolean-mask indexing below.
    x1 = np.atleast_1d(np.asarray(x1, dtype=float))
    x2 = np.atleast_1d(np.asarray(x2, dtype=float))

    # Report a shape mismatch as such instead of as a NaN-position mismatch.
    if x1.shape != x2.shape:
        raise AssertionError(
            f"{msg}: shapes do not match ({x1.shape} vs {x2.shape})"
        )

    # handle nan values
    nan_mask_x1 = np.isnan(x1)
    nan_mask_x2 = np.isnan(x2)

    if not np.array_equal(nan_mask_x1, nan_mask_x2):
        raise AssertionError(f"{msg}: NaN positions do not match")

    valid_mask = ~nan_mask_x1  # Mask for non-NaN elements (same for x1 and x2)

    # Explicit raise (not a bare assert) so the check survives `python -O`.
    if not np.all(np.abs(x1[valid_mask] - x2[valid_mask]) < tol):
        raise AssertionError(msg)
| 40 | + |
| 41 | + |
def _get_r_tidy(r_fit):
    """Return broom's tidy table for an R fixest fit, indexed by term.

    The result of ``broom.tidy_fixest`` (called with ``conf_int`` set to an
    R boolean vector holding ``False``) is wrapped in a DataFrame,
    transposed, given five fixed column names, and indexed by ``"term"``.
    """
    without_conf_int = ro.BoolVector([False])
    r_table = broom.tidy_fixest(r_fit, conf_int=without_conf_int)

    # Transpose before assigning the five column names.
    tidy_frame = pd.DataFrame(r_table).T
    tidy_frame.columns = [
        "term",
        "estimate",
        "std.error",
        "statistic",
        "p.value",
    ]
    return tidy_frame.set_index("term")
| 48 | + |
| 49 | + |
@pytest.fixture(scope="module")
def mtcars_data():
    """Module-scoped fixture: R's ``mtcars`` object passed through ``pandas2ri.rpy2py``."""
    return pandas2ri.rpy2py(ro.r["mtcars"])
| 55 | + |
| 56 | + |
@pytest.mark.against_r_core
@pytest.mark.parametrize(
    "formula",
    [
        "mpg ~ hp + wt + C(cyl)",
        "mpg ~ hp + wt + C(cyl) + C(gear)",
        "mpg ~ hp * wt + C(cyl)",
    ],
)
@pytest.mark.parametrize("vcov", ["iid", "hetero"])
def test_feols_mtcars(mtcars_data, formula, vcov):
    """Compare the ``wt`` estimate and standard error of ``pf.feols`` with R's ``feols``."""
    # The R call gets the same formula with every "C(" rewritten to "factor(".
    fml_for_r = formula.replace("C(", "factor(")

    # NOTE(review): the positional fixest.ssc arguments presumably mirror the
    # pf.ssc settings used below — confirm against the installed fixest version.
    fit_r = fixest.feols(
        ro.Formula(fml_for_r),
        data=mtcars_data,
        vcov=vcov,
        ssc=fixest.ssc(True, "none", False, True, "min", "min"),
    )
    fit_py = pf.feols(
        fml=formula,
        data=mtcars_data,
        vcov=vcov,
        ssc=pf.ssc(k_adj=True, k_fixef="none", G_adj=True),
    )

    table_r = _get_r_tidy(fit_r)
    table_py = fit_py.tidy()

    # (R column, pyfixest column, failure message) for each quantity checked.
    comparisons = (
        ("estimate", "Estimate", f"Coefficients don't match for {formula}"),
        (
            "std.error",
            "Std. Error",
            f"Standard errors don't match for {formula}, vcov={vcov}",
        ),
    )
    for r_column, py_column, failure_msg in comparisons:
        r_value = table_r.loc["wt", r_column]
        py_value = table_py.loc["wt", py_column]
        check_absolute_diff(py_value, r_value, atol, failure_msg)
| 98 | + |
| 99 | + |
@pytest.mark.against_r_core
@pytest.mark.parametrize(
    "formula",
    [
        "mpg ~ hp + wt + C(cyl)",
        "mpg ~ hp + wt + C(cyl) + C(gear)",
        "mpg ~ hp * wt + C(cyl)",
    ],
)
@pytest.mark.parametrize("vcov", ["iid", "hetero"])
def test_feglm_gaussian_mtcars(mtcars_data, formula, vcov):
    """Compare the ``wt`` estimate and standard error of Gaussian ``pf.feglm`` with R's ``feglm``."""
    # The R call gets the same formula with every "C(" rewritten to "factor(".
    fml_for_r = formula.replace("C(", "factor(")

    fit_r = fixest.feglm(
        ro.Formula(fml_for_r),
        data=mtcars_data,
        family=stats.gaussian(),
        vcov=vcov,
    )
    fit_py = pf.feglm(
        fml=formula,
        data=mtcars_data,
        family="gaussian",
        vcov=vcov,
    )

    table_r = _get_r_tidy(fit_r)
    table_py = fit_py.tidy()

    # (R column, pyfixest column, failure message) for each quantity checked.
    comparisons = (
        ("estimate", "Estimate", f"Coefficients don't match for {formula}"),
        (
            "std.error",
            "Std. Error",
            f"Standard errors don't match for feglm {formula}, vcov={vcov}",
        ),
    )
    for r_column, py_column, failure_msg in comparisons:
        r_value = table_r.loc["wt", r_column]
        py_value = table_py.loc["wt", py_column]
        check_absolute_diff(py_value, r_value, atol, failure_msg)
| 144 | + |
| 145 | + |
@pytest.mark.against_r_core
@pytest.mark.parametrize(
    "formula",
    [
        "mpg ~ hp + wt | cyl",
        "mpg ~ hp + wt | cyl + gear",
    ],
)
@pytest.mark.parametrize("vcov", ["iid", "hetero"])
def test_feglm_gaussian_with_fe_mtcars(mtcars_data, formula, vcov):
    """Compare the ``wt`` estimate and standard error of Gaussian ``pf.feglm`` with R's ``feglm`` for fixed-effects formulas."""
    # The same formula string is passed unchanged to both R and pyfixest.
    fit_r = fixest.feglm(
        ro.Formula(formula),
        data=mtcars_data,
        family=stats.gaussian(),
        vcov=vcov,
    )
    fit_py = pf.feglm(
        fml=formula,
        data=mtcars_data,
        family="gaussian",
        vcov=vcov,
    )

    table_r = _get_r_tidy(fit_r)
    table_py = fit_py.tidy()

    # (R column, pyfixest column, failure message) for each quantity checked.
    comparisons = (
        (
            "estimate",
            "Estimate",
            f"Coefficients don't match for {formula} with FE",
        ),
        (
            "std.error",
            "Std. Error",
            f"Standard errors don't match for feglm {formula} with FE, vcov={vcov}",
        ),
    )
    for r_column, py_column, failure_msg in comparisons:
        r_value = table_r.loc["wt", r_column]
        py_value = table_py.loc["wt", py_column]
        check_absolute_diff(py_value, r_value, atol, failure_msg)
| 189 | + |
| 190 | + |
0 commit comments