Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,50 @@ def test_init_dataset_pandas_feature_columns(self):
self.assertEqual(vm_dataset.feature_columns_categorical, [])
self.assertEqual(vm_dataset.feature_columns, ["col1"])

def test_get_numeric_columns(self):
    """
    Check that get_numeric_columns reports every numeric column held by the
    dataset (the target included) rather than only the numeric feature columns.
    """
    # Mixed frame: a numeric feature, an object feature, a text column and
    # a numeric target.
    mixed_frame = pd.DataFrame(
        {
            "col1": [1, 2, 3],
            "col2": ["a", "b", "c"],
            "text_col": ["first review", "second review", "third review"],
            "target": [0, 1, 0],
        }
    )
    mixed_dataset = DataFrameDataset(
        raw_dataset=mixed_frame,
        target_column="target",
        text_column="text_col",
        feature_columns=["col1", "col2"],
    )
    found = mixed_dataset.get_numeric_columns()
    self.assertEqual(set(found), {"col1", "target"})
    for non_numeric in ("col2", "text_col"):
        self.assertNotIn(non_numeric, found)

    # feature_columns_numeric is limited to the numeric *feature* columns ...
    self.assertEqual(mixed_dataset.feature_columns_numeric, ["col1"])
    # ... whereas get_numeric_columns also reports the numeric target.
    self.assertIn("target", found)

    # Frame made up solely of numeric columns: every column is reported.
    numeric_frame = pd.DataFrame({"a": [1, 2], "b": [3.0, 4.0], "c": [5, 6]})
    numeric_dataset = DataFrameDataset(raw_dataset=numeric_frame)
    all_numeric = numeric_dataset.get_numeric_columns()
    self.assertEqual(set(all_numeric), {"a", "b", "c"})

    # Frame holding string columns only: nothing is reported.
    string_frame = pd.DataFrame({"x": ["a", "b"], "y": ["c", "d"]})
    string_dataset = DataFrameDataset(raw_dataset=string_frame)
    self.assertEqual(string_dataset.get_numeric_columns(), [])

def test_dtype_preserved(self):
"""
Test that dtype is preserved in DataFrameDataset.
Expand Down
15 changes: 15 additions & 0 deletions validmind/vm_models/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,21 @@ def df(self) -> pd.DataFrame:
else:
return as_df(self._df[columns]).copy()

def get_numeric_columns(self) -> List[str]:
    """
    List the columns of the underlying DataFrame whose dtype is numeric.

    Every column stored in ``self._df`` is inspected, so the result is not
    restricted to feature columns the way ``feature_columns_numeric`` is;
    a numeric target column, for instance, is reported too.

    Returns:
        List[str]: Names of the numeric columns, in DataFrame column order.
    """
    is_numeric = pd.api.types.is_numeric_dtype
    # Walk the (column name, dtype) pairs and keep the names whose dtype
    # pandas classifies as numeric.
    return [
        column
        for column, dtype in self._df.dtypes.items()
        if is_numeric(dtype)
    ]

@property
def x(self) -> np.ndarray:
"""
Expand Down