学研君 2025-05-25 11:15 浙江
速查!用这些代码片段简化机器学习全流程。
构建机器学习模型是数据科学的关键环节,涉及运用算法进行数据预测或挖掘数据中的模式。
本文分享一系列简洁的代码片段,涵盖机器学习过程的各个阶段,从数据准备、模型选择,到模型评估和超参数调优。这些代码示例能帮助你使用诸如Scikit-Learn、XGBoost、CatBoost、LightGBM等库,完成常见的机器学习任务,还包含使用Hyperopt进行超参数优化、利用SHAP值进行模型解释等高级技术。
借助这些快速参考代码,你可以简化机器学习工作流程,在不同领域开发出高效的预测模型。
一、数据处理与探索
data = pd.read_csv('dataset.csv')

# Quick inspection: first rows, dtypes / non-null counts, summary statistics.
data.head()
data.info()
data.describe()

# Missing values -- two alternatives. Both return a NEW DataFrame and leave
# `data` untouched, so keep the result. fillna() needs an explicit fill value;
# calling it with no arguments raises ValueError.
data_dropped = data.dropna()
data_filled = data.fillna(data.mean(numeric_only=True))

# One-hot encode categorical columns (also returns a new DataFrame).
data_encoded = pd.get_dummies(data)

# Hold out 20% of the rows for testing. X / y (features / target) are assumed
# to have been defined from `data` beforehand.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit the scaler on the training split only and reuse it on the test split,
# so no test-set statistics leak into training.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
二、模型初始化、训练与评估
model = RandomForestClassifier()
model.fit(X_train, y_train)
predictions = model.predict(X_test)

# Classification metrics on the held-out test set.
accuracy = accuracy_score(y_test, predictions)
conf_matrix = confusion_matrix(y_test, predictions)
class_report = classification_report(y_test, predictions)

# 5-fold cross-validated scores.
cv_scores = cross_val_score(model, X, y, cv=5)

# Exhaustive hyperparameter search. Search on the training split only so the
# test set stays untouched for the final evaluation. param_grid maps
# parameter names to candidate values (example values shown).
param_grid = {'n_estimators': [100, 200], 'max_depth': [None, 10]}
grid_search = GridSearchCV(model, param_grid, cv=5)
grid_search.fit(X_train, y_train)

# Impurity-based feature importances of the fitted forest.
feature_importance = model.feature_importances_

# Persist and reload the fitted model. joblib.load unpickles the file, so
# only load files from a trusted source.
joblib.dump(model, 'model.pkl')
loaded_model = joblib.load('model.pkl')
三、降维和聚类
# Project X onto its first two principal components.
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)

# K-means with 3 clusters; labels_ holds the cluster index of each sample.
kmeans = KMeans(n_clusters=3)
kmeans.fit(X)
labels = kmeans.labels_

# Elbow method: within-cluster sum of squares (inertia_) for k = 1..10.
# Constructing, fitting and recording must all sit inside the loop body so
# every k gets its own fitted model. A separate loop variable keeps `kmeans`
# bound to the 3-cluster model above.
sum_of_squared_distances = []
for k in range(1, 11):
    km = KMeans(n_clusters=k)
    km.fit(X)
    sum_of_squared_distances.append(km.inertia_)

# Mean silhouette coefficient of the 3-cluster labelling.
silhouette_avg = silhouette_score(X, labels)
四、各类分类模型(含部分回归模型)
# Every estimator follows the same construct -> fit(X_train, y_train) pattern.
dt_model = DecisionTreeClassifier()
dt_model.fit(X_train, y_train)

svm_model = SVC()
svm_model.fit(X_train, y_train)

nb_model = GaussianNB()
nb_model.fit(X_train, y_train)

knn_model = KNeighborsClassifier()
knn_model.fit(X_train, y_train)

# Regressor counterpart of k-NN -- expects a continuous target.
knn_reg_model = KNeighborsRegressor(n_neighbors=5).fit(X_train, y_train)

logreg_model = LogisticRegression()
logreg_model.fit(X_train, y_train)

# Ridge and Lasso are linear *regressors* (continuous y); for an
# L2-penalised classifier see RidgeClassifier.
ridge_model = Ridge()
ridge_model.fit(X_train, y_train)

lasso_model = Lasso()
lasso_model.fit(X_train, y_train)

# Soft voting averages predict_proba outputs, so every member must support it.
clf1 = LogisticRegression()
clf2 = RandomForestClassifier(n_estimators=100)
ensemble_model = VotingClassifier(estimators=[('clf1', clf1), ('clf2', clf2)], voting='soft')
ensemble_model.fit(X_train, y_train)

# scikit-learn >= 1.2 calls this argument `estimator`; `base_estimator` was
# deprecated in 1.2 and removed in 1.4.
bagging_model = BaggingClassifier(estimator=DecisionTreeClassifier(), n_estimators=100)
bagging_model.fit(X_train, y_train)

rf_model = RandomForestClassifier(n_estimators=100)
rf_model.fit(X_train, y_train)

gb_model = GradientBoostingClassifier()
gb_model.fit(X_train, y_train)

adaboost_model = AdaBoostClassifier()
adaboost_model.fit(X_train, y_train)

# Third-party gradient-boosting libraries with scikit-learn-style wrappers.
xgb_model = xgb.XGBClassifier()
xgb_model.fit(X_train, y_train)

lgb_model = lgb.LGBMClassifier()
lgb_model.fit(X_train, y_train)

catboost_model = CatBoostClassifier()
catboost_model.fit(X_train, y_train)
五、模型评估指标
# Class probabilities for the test set; in a binary problem column 1 is the
# positive class (columns follow model.classes_).
predictions_prob = model.predict_proba(X_test)

# ROC curve and the area under it.
fpr, tpr, roc_thresholds = roc_curve(y_test, predictions_prob[:, 1])
roc_auc = roc_auc_score(y_test, predictions_prob[:, 1])

# Precision-recall curve and its area. A separate threshold name keeps the
# ROC thresholds from being overwritten.
precision, recall, pr_thresholds = precision_recall_curve(y_test, predictions_prob[:, 1])
pr_auc = auc(recall, precision)

# F1 on hard class predictions (binary by default; pass average=... for
# multiclass targets).
f1 = f1_score(y_test, predictions)

# Regression metrics -- meaningful only for a regressor's continuous predictions.
mse = mean_squared_error(y_test, predictions)
r2 = r2_score(y_test, predictions)
六、交叉验证和采样技术
# CV splitters to pass as cv=...: stratified folds keep class proportions in
# every fold; TimeSeriesSplit never trains on data later than its test fold.
stratified_kfold = StratifiedKFold(n_splits=5)
time_series_split = TimeSeriesSplit(n_splits=5)

# Class rebalancing (imbalanced-learn samplers). Resample the TRAINING split
# only -- resampling before the split puts duplicated/synthetic rows derived
# from training data into the test set.
rus = RandomUnderSampler()
X_resampled, y_resampled = rus.fit_resample(X_train, y_train)

ros = RandomOverSampler()
X_resampled, y_resampled = ros.fit_resample(X_train, y_train)

smote = SMOTE()
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)

# Alternative to resampling: pass this as an estimator's class_weight keyword
# to weight classes inversely to their frequency.
class_weight = 'balanced'

# Same configuration as stratified_kfold.
stratified_cv = StratifiedKFold(n_splits=5)
七、特征工程与转换
# scikit-learn ships learning/validation-curve plots as Display classes
# (LearningCurveDisplay >= 1.2, ValidationCurveDisplay >= 1.3).
from sklearn.model_selection import LearningCurveDisplay, ValidationCurveDisplay

LearningCurveDisplay.from_estimator(model, X, y)
# 'param' is a placeholder: substitute a real hyperparameter name of `model`
# and its candidate values in param_range.
ValidationCurveDisplay.from_estimator(model, X, y, param_name='param', param_range=param_range)

# Early-stopping patience for boosting libraries (XGBoost takes it in the
# estimator constructor; LightGBM uses the lgb.early_stopping callback).
early_stopping_rounds = 10

# Rescale each feature to the [0, 1] range.
scaler = MinMaxScaler(feature_range=(0, 1))
X_scaled = scaler.fit_transform(X)

# One-hot encode categorical columns.
data_encoded = pd.get_dummies(data)

# LabelEncoder maps each distinct label to an integer; it is meant for target
# labels -- prefer OneHotEncoder/OrdinalEncoder for input features.
label_encoder = LabelEncoder()
data['label_encoded'] = label_encoder.fit_transform(data['label'])

# Standardisation: zero mean and unit variance per feature.
scaler = StandardScaler()
X_standardized = scaler.fit_transform(X)

# Min-max normalisation to [0, 1] (MinMaxScaler's default range).
scaler = MinMaxScaler()
X_normalized = scaler.fit_transform(X)

# log(1 + x) to compress right-skewed values. Restricted to numeric columns
# because np.log1p cannot handle strings; values must be greater than -1.
X_transformed = np.log1p(data.select_dtypes(include='number'))

# Outlier detection: fit_predict returns 1 for inliers and -1 for outliers.
iso_forest = IsolationForest()
outliers = iso_forest.fit_predict(X)

# contamination = expected proportion of outliers in the data.
envelope = EllipticEnvelope(contamination=0.01)
outliers = envelope.fit_predict(X)

# Replace missing values with the column mean.
imputer = SimpleImputer(strategy='mean')
X_imputed = imputer.fit_transform(X)

# Degree-2 polynomial and interaction terms (plus a bias column by default).
poly = PolynomialFeatures(degree=2)
X_poly = poly.fit_transform(X)
八、回归模型与技术
# L1-regularised linear regression.
lasso = Lasso(alpha=1.0)
lasso.fit(X_train, y_train)

# L2-regularised linear regression.
ridge = Ridge(alpha=1.0)
ridge.fit(X_train, y_train)

# Huber loss: less sensitive to outliers than squared error.
huber = HuberRegressor()
huber.fit(X_train, y_train)

# Median (q=0.5) regression with statsmodels. Note the (endog, exog) argument
# order, target first. QuantReg does not add an intercept by itself -- wrap
# X_train with statsmodels.api.add_constant if one is needed.
quantile_reg = QuantReg(y_train, X_train)
quantile_result = quantile_reg.fit(q=0.5)

# RANSAC: robust fit estimated from a consensus set of inliers.
ransac = RANSACRegressor()
ransac.fit(X_train, y_train)
九、自动化机器学习和高级技术
# TPOT searches over whole scikit-learn pipelines; runs can take a long time.
tpot = TPOTClassifier()
tpot.fit(X_train, y_train)

# H2O AutoML trains on an H2OFrame. `train` is assumed to be such a frame
# (e.g. built with h2o.H2OFrame) holding the predictor columns plus 'target'.
# H2O expects x as a plain list of column names, not a pandas Index.
h2o_automl = H2OAutoML(max_models=10, seed=1)
h2o_automl.train(x=list(X_train.columns), y='target', training_frame=train)
十、绘图与可视化
# Scatter of the first two features coloured by 3-cluster K-means labels.
# X[:, 0] indexing requires a NumPy array (use X.iloc[:, 0] for a DataFrame).
plt.scatter(X[:, 0], X[:, 1], c=KMeans(n_clusters=3).fit_predict(X), cmap='viridis')

# Save after drawing: savefig writes whatever is on the current figure.
plt.savefig('plot.png')

# Feature importances of a fitted tree-based model, one bar per feature,
# drawn on a new figure so it does not overlay the scatter plot.
plt.figure()
plt.bar(range(len(model.feature_importances_)), model.feature_importances_)
plt.xlabel('feature index')
plt.ylabel('importance')
十一、其他
# Out-of-fold predictions from 5-fold cross-validation.
cv_predictions = cross_val_predict(model, X, y, cv=5)

# Evaluate a user-defined metric. Store the result under a different name:
# assigning to `custom_metric` would replace the function with its own output.
custom_score = custom_metric(y_true, y_pred)

# Keep the 5 best features by chi-squared score (chi2 needs non-negative features).
kbest = SelectKBest(chi2, k=5)
X_selected = kbest.fit_transform(X, y)

# Recursive feature elimination, number of features chosen by 5-fold CV.
rfecv = RFECV(estimator=DecisionTreeClassifier(), step=1, cv=5)
X_rfecv = rfecv.fit_transform(X, y)

poly = PolynomialFeatures(degree=2)
X_poly = poly.fit_transform(X)

# Frequently used keyword-argument values for estimators.
class_weight = 'balanced'
learning_rate = 0.1
random_state = 42

# Estimators with explicitly chosen hyperparameters.
ridge = Ridge(alpha=1.0)
ridge.fit(X_train, y_train)

lasso = Lasso(alpha=1.0)
lasso.fit(X_train, y_train)

dt_model = DecisionTreeClassifier(max_depth=3)
dt_model.fit(X_train, y_train)

knn_model = KNeighborsClassifier(n_neighbors=5)
knn_model.fit(X_train, y_train)

svm_model = SVC(kernel='rbf')
svm_model.fit(X_train, y_train)

rf_model = RandomForestClassifier(n_estimators=100)
rf_model.fit(X_train, y_train)

gb_model = GradientBoostingClassifier(learning_rate=0.1)
gb_model.fit(X_train, y_train)

# Tune Huber's epsilon by grid search; RidgeCV selects alpha by built-in CV.
# fit() returns the fitted estimator, so keep it to inspect best_params_ / alpha_.
huber_search = GridSearchCV(HuberRegressor(), {'epsilon': [1.1, 1.2, 1.3]}, cv=5).fit(X_train, y_train)
ridge_cv = RidgeCV(alphas=[0.1, 1.0, 10.0], cv=5).fit(X_train, y_train)

# scikit-learn's StackingClassifier takes (name, estimator) pairs plus a
# final_estimator, matching the VotingClassifier style. The
# classifiers=/meta_classifier= keywords belong to mlxtend's class of the same name.
stacked_model = StackingClassifier(estimators=[('clf1', clf1), ('clf2', clf2)], final_estimator=meta_clf)
stacked_model.fit(X_train, y_train)