def check_nan(df, nan_policy, *, cols=None):
    """
    Check df against *nan_policy* and return cleaned input

    :param df: Data to check for NaN
    :type df: DataFrame
    :param nan_policy: Policy to apply when encountering NaN values (*warn*, *raise*, *omit*)
    :type nan_policy: str
    :param cols: Columns to check for NaN, or *None* for all columns (a single column name is also accepted)
    :type cols: List[str]

    :return: Data with NaNs removed, which may or may not be copied
    :rtype: DataFrame
    :raises ValueError: If *nan_policy* is *raise* and the checked columns contain NaN, or if *nan_policy* is not one of *raise*, *warn*, *omit*
    """

    # Treat a bare column name as a one-element list, so that df[cols] stays a
    # DataFrame and dropna(subset=...) always receives a sequence of labels
    if isinstance(cols, str):
        cols = [cols]

    if nan_policy == 'raise':
        # Only the requested columns decide whether the input is rejected
        df_to_check = df if cols is None else df[cols]
        if pd.isna(df_to_check).any(axis=None):
            raise ValueError('NaN in input, pass nan_policy="warn" or "omit" to ignore')

        # Input is returned as-is (not copied) when it passes the check
        return df
    elif nan_policy == 'warn':
        df_cleaned = df.dropna(subset=cols)
        if len(df_cleaned) < len(df):
            warnings.warn('Omitting {} rows with NaN'.format(len(df) - len(df_cleaned)))

        return df_cleaned
    elif nan_policy == 'omit':
        return df.dropna(subset=cols)
    else:
        # An unrecognised policy is a bad argument value; ValueError is a subclass
        # of Exception, so callers catching Exception are unaffected
        raise ValueError('Invalid nan_policy, expected "raise", "warn" or "omit"')