Make yli.mannwhitney support pandas nullable dtypes
This commit is contained in:
parent
ee36ac9d14
commit
b31ae6686f
@ -29,7 +29,7 @@ import warnings
|
||||
from .bayes_factors import BayesFactor, bayesfactor_afbf
|
||||
from .config import config
|
||||
from .sig_tests import FTestResult
|
||||
from .utils import Estimate, check_nan, cols_for_formula, fmt_p, formula_factor_ref_category, parse_patsy_term
|
||||
from .utils import Estimate, check_nan, cols_for_formula, convert_pandas_nullable, fmt_p, formula_factor_ref_category, parse_patsy_term
|
||||
|
||||
def vif(df, formula=None, *, nan_policy='warn'):
|
||||
"""
|
||||
@ -534,10 +534,8 @@ def regress(
|
||||
if df[dep].dtype != 'float64':
|
||||
df[dep] = df[dep].astype('float64')
|
||||
|
||||
# Convert pandas nullable types for independent variables
|
||||
for col in df.columns:
|
||||
if df[col].dtype == 'Int64':
|
||||
df[col] = df[col].astype('float64')
|
||||
# Convert pandas nullable types for independent variables as this breaks statsmodels
|
||||
df = convert_pandas_nullable(df)
|
||||
|
||||
# Fit model
|
||||
model = model_class.from_formula(formula=dep + ' ~ ' + formula, data=df, **model_kwargs)
|
||||
|
@ -23,7 +23,7 @@ import functools
|
||||
import warnings
|
||||
|
||||
from .config import config
|
||||
from .utils import Estimate, as_2groups, check_nan, fmt_p
|
||||
from .utils import Estimate, as_2groups, check_nan, convert_pandas_nullable, fmt_p
|
||||
|
||||
# ----------------
|
||||
# Student's t test
|
||||
@ -334,6 +334,9 @@ def mannwhitney(df, dep, ind, *, nan_policy='warn', brunnermunzel=True, use_cont
|
||||
# Check for/clean NaNs
|
||||
df = check_nan(df[[ind, dep]], nan_policy)
|
||||
|
||||
# Convert pandas nullable types (in both the ind and dep columns), as downstream numerical code may not handle them
|
||||
df = convert_pandas_nullable(df)
|
||||
|
||||
# Ensure 2 groups for ind
|
||||
group1, data1, group2, data2 = as_2groups(df, dep, ind)
|
||||
|
||||
|
12
yli/utils.py
12
yli/utils.py
@ -41,6 +41,18 @@ def check_nan(df, nan_policy):
|
||||
else:
|
||||
raise Exception('Invalid nan_policy, expected "raise", "warn" or "omit"')
|
||||
|
||||
def convert_pandas_nullable(df):
    """
    Convert pandas nullable integer dtypes (e.g. Int64) to non-nullable numpy dtypes

    Every column whose dtype is one of the pandas nullable integer extension
    dtypes (Int8 ... Int64, UInt8 ... UInt64) is cast to a plain numpy dtype:

    * If the column has no missing values, it is cast to the numpy integer
      dtype of the same width and signedness (e.g. Int64 -> int64).
    * If the column still contains <NA>, it is cast to float64 so that the
      missing values become NaN, as a numpy integer dtype cannot represent them.

    All other columns are left unchanged.

    :param df: Data to convert
    :type df: DataFrame
    :return: A copy of *df* with the nullable integer columns converted; the input is not modified
    :rtype: DataFrame
    """

    # Names of the pandas nullable integer dtypes; lower-casing a name gives
    # the equivalent fixed-width numpy dtype name
    nullable_int_dtypes = {
        'Int8', 'Int16', 'Int32', 'Int64',
        'UInt8', 'UInt16', 'UInt32', 'UInt64',
    }

    # Work on a copy so the caller's DataFrame is never mutated
    # TODO: Can we avoid this copy?
    df = df.copy()

    for col in df.columns:
        # Compare by dtype name: this is case-sensitive, so 'Int64' (nullable)
        # is never confused with 'int64' (numpy)
        dtype_name = str(df[col].dtype)
        if dtype_name not in nullable_int_dtypes:
            continue

        if df[col].isna().any():
            # Integers cannot hold missing values - fall back to float64/NaN
            df[col] = df[col].astype('float64')
        else:
            # Use an explicit width: plain 'int' is platform-dependent
            df[col] = df[col].astype(dtype_name.lower())

    return df
|
||||
|
||||
def as_2groups(df, data, group):
|
||||
"""Group the data by the given variable, ensuring only 2 groups"""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user