diff --git a/yli/regress.py b/yli/regress.py index b7af386..b386302 100644 --- a/yli/regress.py +++ b/yli/regress.py @@ -29,7 +29,7 @@ import warnings from .bayes_factors import BayesFactor, bayesfactor_afbf from .config import config from .sig_tests import FTestResult -from .utils import Estimate, check_nan, cols_for_formula, fmt_p, formula_factor_ref_category, parse_patsy_term +from .utils import Estimate, check_nan, cols_for_formula, convert_pandas_nullable, fmt_p, formula_factor_ref_category, parse_patsy_term def vif(df, formula=None, *, nan_policy='warn'): """ @@ -534,10 +534,8 @@ def regress( if df[dep].dtype != 'float64': df[dep] = df[dep].astype('float64') - # Convert pandas nullable types for independent variables - for col in df.columns: - if df[col].dtype == 'Int64': - df[col] = df[col].astype('float64') + # Convert pandas nullable types for independent variables as this breaks statsmodels + df = convert_pandas_nullable(df) # Fit model model = model_class.from_formula(formula=dep + ' ~ ' + formula, data=df, **model_kwargs) diff --git a/yli/sig_tests.py b/yli/sig_tests.py index 5d63321..2a63819 100644 --- a/yli/sig_tests.py +++ b/yli/sig_tests.py @@ -23,7 +23,7 @@ import functools import warnings from .config import config -from .utils import Estimate, as_2groups, check_nan, fmt_p +from .utils import Estimate, as_2groups, check_nan, convert_pandas_nullable, fmt_p # ---------------- # Student's t test @@ -334,6 +334,9 @@ def mannwhitney(df, dep, ind, *, nan_policy='warn', brunnermunzel=True, use_cont # Check for/clean NaNs df = check_nan(df[[ind, dep]], nan_policy) + # Convert pandas nullable types for independent variables as this breaks statsmodels + df = convert_pandas_nullable(df) + # Ensure 2 groups for ind group1, data1, group2, data2 = as_2groups(df, dep, ind) diff --git a/yli/utils.py b/yli/utils.py index 5887505..a19acc5 100644 --- a/yli/utils.py +++ b/yli/utils.py @@ -41,6 +41,18 @@ def check_nan(df, nan_policy): else: raise Exception('Invalid nan_policy, expected "raise", "warn" or "omit"') +def convert_pandas_nullable(df): + """Convert pandas nullable dtypes (e.g. Int64) to non-nullable numpy dtypes""" + + # TODO: Can we avoid this copy? + df = df.copy() + + for col in df.columns: + if df[col].dtype == 'Int64': + df[col] = df[col].astype('int') + + return df + def as_2groups(df, data, group): """Group the data by the given variable, ensuring only 2 groups"""