diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 45423a88d..4e9527220 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -66,6 +66,50 @@ def test_init_dataset_pandas_feature_columns(self): self.assertEqual(vm_dataset.feature_columns_categorical, []) self.assertEqual(vm_dataset.feature_columns, ["col1"]) + def test_get_numeric_columns(self): + """ + Test that get_numeric_columns returns all numeric columns in the dataset, + including target and extra columns, not just feature columns. + """ + # Dataset: col1 (numeric), col2 (object/categorical), text_col (text), + # target (numeric) + df_with_text = pd.DataFrame( + { + "col1": [1, 2, 3], + "col2": ["a", "b", "c"], + "text_col": ["first review", "second review", "third review"], + "target": [0, 1, 0], + } + ) + vm_dataset = DataFrameDataset( + raw_dataset=df_with_text, + target_column="target", + text_column="text_col", + feature_columns=["col1", "col2"], + ) + numeric_cols = vm_dataset.get_numeric_columns() + self.assertEqual(set(numeric_cols), {"col1", "target"}) + self.assertNotIn("col2", numeric_cols) + self.assertNotIn("text_col", numeric_cols) + + # feature_columns_numeric only has feature columns that are numeric + self.assertEqual(vm_dataset.feature_columns_numeric, ["col1"]) + # get_numeric_columns includes target as well + self.assertIn("target", numeric_cols) + + # All-numeric dataset: get_numeric_columns returns all columns + df_numeric = pd.DataFrame({"a": [1, 2], "b": [3.0, 4.0], "c": [5, 6]}) + vm_numeric = DataFrameDataset(raw_dataset=df_numeric) + self.assertEqual( + set(vm_numeric.get_numeric_columns()), + {"a", "b", "c"}, + ) + + # No numeric columns (only object/categorical/text) + df_cat = pd.DataFrame({"x": ["a", "b"], "y": ["c", "d"]}) + vm_cat = DataFrameDataset(raw_dataset=df_cat) + self.assertEqual(vm_cat.get_numeric_columns(), []) + def test_dtype_preserved(self): """ Test that dtype is preserved in DataFrameDataset. diff --git a/validmind/vm_models/dataset/dataset.py b/validmind/vm_models/dataset/dataset.py index c59d37077..f2b3e0439 100644 --- a/validmind/vm_models/dataset/dataset.py +++ b/validmind/vm_models/dataset/dataset.py @@ -885,6 +885,21 @@ def df(self) -> pd.DataFrame: else: return as_df(self._df[columns]).copy() + def get_numeric_columns(self) -> List[str]: + """ + Returns the names of all columns in the dataset that have numeric dtype. + + Unlike feature_columns_numeric, this includes every column in the dataset + (target, extra columns, prediction columns, etc.) that is numeric. + + Returns: + List[str]: The names of all numeric columns in the dataset. + """ + dtypes = self._df.dtypes + return dtypes[ + dtypes.apply(lambda x: pd.api.types.is_numeric_dtype(x)) + ].index.tolist() + @property def x(self) -> np.ndarray: """