Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

python - How can I standardize only numeric variables in an sklearn pipeline?

I am trying to create an sklearn pipeline with 2 steps:

  1. Standardize the data
  2. Fit the data using KNN

However, my data has both numeric and categorical variables, and I have converted the categorical ones to dummies with pd.get_dummies. I want to standardize the numeric variables but leave the dummies as they are. Currently I do it like this:

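import pandas as pd
from sklearn.preprocessing import StandardScaler

X = ...  # DataFrame containing both numeric and categorical (dummy) columns
numeric = [...]  # list of numeric column names
categorical = [...]  # list of categorical (dummy) column names
scaler = StandardScaler()
# index=X.index keeps the rows aligned for the index-based merge below
X_numeric_std = pd.DataFrame(data=scaler.fit_transform(X[numeric]), columns=numeric, index=X.index)
X_std = pd.merge(X_numeric_std, X[categorical], left_index=True, right_index=True)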

However, if I were to create a pipeline like:

pipe = sklearn.pipeline.make_pipeline(StandardScaler(), KNeighborsClassifier())

it would standardize all of the columns in my DataFrame, including the dummies. Is there a way to standardize only the numeric columns inside the pipeline?


Answer:

UPD: 2021-05-10

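For scikit-learn >= 0.20 you can use sklearn.compose.ColumnTransformer, which applies a separate transformer to each listed subset of columns and concatenates the results.

Here is a small example (note that fetch_openml(..., as_frame=True) requires scikit-learn >= 0.22):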

Imports and data loading:

# Author: Pedro Morales <part.morales@gmail.com>
#
# License: BSD 3 clause

import numpy as np

from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_openml
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, GridSearchCV

np.random.seed(0)

# Load data from https://www.openml.org/d/40945
X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)

Pipeline-aware data preprocessing using ColumnTransformer:

numeric_features = ['age', 'fare']
numeric_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler())])

categorical_features = ['embarked', 'sex', 'pclass']
categorical_transformer = OneHotEncoder(handle_unknown='ignore')

preprocessor = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, numeric_features),
        ('cat', categorical_transformer, categorical_features)])

Classification:

# Append classifier to preprocessing pipeline.
# Now we have a full prediction pipeline.
clf = Pipeline(steps=[('preprocessor', preprocessor),
                      ('classifier', LogisticRegression())])

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                    random_state=0)

clf.fit(X_train, y_train)
print("model score: %.3f" % clf.score(X_test, y_test))
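The example above uses LogisticRegression and lets OneHotEncoder create the dummies. For the setup in the question, where the dummies already exist (made with pd.get_dummies) and the model is KNN, a minimal sketch is to scale only the listed numeric columns and pass every other column through unchanged with remainder="passthrough". The toy data and column names ("age", "income", "city") are hypothetical. Because the scaler sits inside the pipeline, GridSearchCV refits it on each training fold.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical toy data: two numeric columns and one categorical column.
rng = np.random.RandomState(0)
raw = pd.DataFrame({
    "age": rng.normal(40, 10, 60),
    "income": rng.normal(50_000, 15_000, 60),
    "city": rng.choice(["A", "B", "C"], 60),
})
y = rng.randint(0, 2, 60)

# As in the question, the dummies are created up front with pd.get_dummies.
X = pd.get_dummies(raw, columns=["city"])
numeric = ["age", "income"]

# Scale only the numeric columns; remainder="passthrough" leaves the dummies as they are.
preprocessor = ColumnTransformer(
    transformers=[("num", StandardScaler(), numeric)],
    remainder="passthrough",
)

pipe = Pipeline(steps=[
    ("preprocessor", preprocessor),
    ("knn", KNeighborsClassifier()),
])

# Pipeline step parameters are addressed as <step name>__<parameter>.
search = GridSearchCV(pipe, param_grid={"knn__n_neighbors": [3, 5, 7]}, cv=3)
search.fit(X, y)
print(search.best_params_)
```

The transformed output puts the scaled numeric columns first, followed by the passed-through dummy columns in their original order.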

OLD Answer:

Assume you have the following DataFrame:

In [163]: df
Out[163]:
     a     b    c    d
0  aaa  1.01  xxx  111
1  bbb  2.02  yyy  222
2  ccc  3.03  zzz  333

In [164]: df.dtypes
Out[164]:
a     object
b    float64
c     object
d      int64
dtype: object

You can find all numeric columns (df.select_dtypes(include='number').columns gives the same result):

In [165]: num_cols = df.columns[df.dtypes.apply(lambda c: np.issubdtype(c, np.number))]

In [166]: num_cols
Out[166]: Index(['b', 'd'], dtype='object')

In [167]: df[num_cols]
Out[167]:
      b    d
0  1.01  111
1  2.02  222
2  3.03  333

Then apply StandardScaler only to those numeric columns:

In [168]: scaler = StandardScaler()

In [169]: df[num_cols] = scaler.fit_transform(df[num_cols])

In [170]: df
Out[170]:
     a         b    c         d
0  aaa -1.224745  xxx -1.224745
1  bbb  0.000000  yyy  0.000000
2  ccc  1.224745  zzz  1.224745

Now you can one-hot encode the categorical (non-numeric) columns, e.g. with pd.get_dummies.
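The dtype-based column selection above can also be done inside a pipeline, so that nothing is fitted outside it. A minimal sketch, assuming scikit-learn >= 0.22 (where sklearn.compose.make_column_selector was added): numeric columns are picked by dtype when the transformer is fit and scaled, while all non-numeric columns are one-hot encoded, in a single ColumnTransformer applied to the same toy DataFrame. sparse_threshold=0 is set only so the result is always a dense array.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer, make_column_selector
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# The toy DataFrame from the example above.
df = pd.DataFrame({
    "a": ["aaa", "bbb", "ccc"],
    "b": [1.01, 2.02, 3.03],
    "c": ["xxx", "yyy", "zzz"],
    "d": [111, 222, 333],
})

preprocessor = ColumnTransformer(
    transformers=[
        # Numeric columns (here "b" and "d") are chosen by dtype, then standardized.
        ("num", StandardScaler(), make_column_selector(dtype_include=np.number)),
        # Every non-numeric column (here "a" and "c") is one-hot encoded.
        ("cat", OneHotEncoder(handle_unknown="ignore"), make_column_selector(dtype_exclude=np.number)),
    ],
    sparse_threshold=0,  # always return a dense NumPy array
)

Xt = preprocessor.fit_transform(df)
print(Xt)
```

The first two output columns match the scaled "b" and "d" columns shown above; the remaining six are the one-hot columns for "a" and "c". The same preprocessor can be placed as the first step of a Pipeline in front of KNeighborsClassifier.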



...