Make yli.mannwhitney support pandas nullable dtypes

This commit is contained in:
RunasSudo 2022-10-20 20:58:42 +11:00
parent ee36ac9d14
commit b31ae6686f
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
3 changed files with 19 additions and 6 deletions

View File

@ -29,7 +29,7 @@ import warnings
from .bayes_factors import BayesFactor, bayesfactor_afbf
from .config import config
from .sig_tests import FTestResult
from .utils import Estimate, check_nan, cols_for_formula, fmt_p, formula_factor_ref_category, parse_patsy_term
from .utils import Estimate, check_nan, cols_for_formula, convert_pandas_nullable, fmt_p, formula_factor_ref_category, parse_patsy_term
def vif(df, formula=None, *, nan_policy='warn'):
"""
@ -534,10 +534,8 @@ def regress(
if df[dep].dtype != 'float64':
df[dep] = df[dep].astype('float64')
# Convert pandas nullable types for independent variables
for col in df.columns:
if df[col].dtype == 'Int64':
df[col] = df[col].astype('float64')
# Convert pandas nullable types for independent variables as this breaks statsmodels
df = convert_pandas_nullable(df)
# Fit model
model = model_class.from_formula(formula=dep + ' ~ ' + formula, data=df, **model_kwargs)

View File

@ -23,7 +23,7 @@ import functools
import warnings
from .config import config
from .utils import Estimate, as_2groups, check_nan, fmt_p
from .utils import Estimate, as_2groups, check_nan, convert_pandas_nullable, fmt_p
# ----------------
# Student's t test
@ -334,6 +334,9 @@ def mannwhitney(df, dep, ind, *, nan_policy='warn', brunnermunzel=True, use_cont
# Check for/clean NaNs
df = check_nan(df[[ind, dep]], nan_policy)
# Convert pandas nullable types for independent variables as this breaks statsmodels
df = convert_pandas_nullable(df)
# Ensure 2 groups for ind
group1, data1, group2, data2 = as_2groups(df, dep, ind)

View File

@ -41,6 +41,18 @@ def check_nan(df, nan_policy):
else:
raise Exception('Invalid nan_policy, expected "raise", "warn" or "omit"')
def convert_pandas_nullable(df):
"""Convert pandas nullable dtypes (e.g. Int64) to non-nullable numpy dtypes"""
# TODO: Can we avoid this copy?
df = df.copy()
for col in df.columns:
if df[col].dtype == 'Int64':
df[col] = df[col].astype('int')
return df
def as_2groups(df, data, group):
"""Group the data by the given variable, ensuring only 2 groups"""