使用Scikit-Learn,快速掌握機(jī)器學(xué)習(xí)預(yù)測方法
在本文中,我們將討論預(yù)測函數(shù)的區(qū)別和它們的用途。
在機(jī)器學(xué)習(xí)中,predict和predict_proba、predict_log_proba和decision_function方法都是用來根據(jù)訓(xùn)練好的模型進(jìn)行預(yù)測的。
predict方法
predict方法用于進(jìn)行二元分類或多元分類預(yù)測,并返回輸入數(shù)據(jù)的預(yù)測類標(biāo)簽。例如,如果你已經(jīng)訓(xùn)練了一個邏輯回歸模型來預(yù)測一個客戶是否會購買產(chǎn)品,則可以使用predict方法來預(yù)測一個新客戶是否會購買產(chǎn)品。
我們將使用來自scikit-learn的乳腺癌數(shù)據(jù)集。這個數(shù)據(jù)集包含了腫瘤觀察結(jié)果和腫瘤是惡性還是良性的相應(yīng)標(biāo)簽。
import numpy as np
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
# 加載數(shù)據(jù)集
dataset = load_breast_cancer(as_frame=True)
# 創(chuàng)建特征和目標(biāo)
X = dataset['data']
y = dataset['target']
# 將數(shù)據(jù)集分割成訓(xùn)練集和測試集
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y , test_size=0.25, random_state=0)
# 我們創(chuàng)建一個簡單的管道來規(guī)范數(shù)據(jù)并使用`SVC`分類器訓(xùn)練模型
svc_clf = make_pipeline(StandardScaler(),SVC(max_iter=1000, probability=True))
svc_clf.fit(X_train, y_train)
# 我們正在預(yù)測X_test的第一個條目
print(svc_clf.predict(X_test[:1]))
# 預(yù)測X_test的第一個條目屬于哪一類
[0]
predict_proba方法
predict_proba方法用于預(yù)測每個類標(biāo)簽的概率,它返回每個可能的類別標(biāo)簽的概率估計(jì)。這種方法通常用于二元或多元分類問題,在這些問題中你想知道每種可能結(jié)果的概率。例如,如果你已經(jīng)訓(xùn)練了一個模型,將動物的圖像分為貓、狗和馬,你可以使用predict_proba方法來獲得每個類別標(biāo)簽的概率估計(jì)。
print(svc_clf.predict_proba(X_test[:1]))
[[0.99848307 0.00151693]]
predict_log_proba方法
predict_log_proba方法與predict_proba類似,但它返回概率估計(jì)值的對數(shù),而不是原始概率。在處理非常小或非常大的概率值時,這可能很有用,因?yàn)樗兄诒苊鈹?shù)值下溢或溢出問題。
print(svc_clf.predict_log_proba(X_test[:1]))
[[-1.51808474e-03 -6.49106473e+00]]
decision_function方法
decision_function方法用于線性模型的二元分類問題。它為每個輸入的數(shù)據(jù)點(diǎn)返回一個分?jǐn)?shù),該分?jǐn)?shù)可用于確定類別標(biāo)簽的預(yù)測。可以根據(jù)應(yīng)用或領(lǐng)域知識來設(shè)置將數(shù)據(jù)點(diǎn)分類為正或負(fù)的閾值。
print(svc_clf.decision_function(X_test[:1]))
[-1.70756057]
總結(jié)
- 當(dāng)你想要得到輸入數(shù)據(jù)的預(yù)測類標(biāo)簽時,對二元或多元分類問題使用predict。
- 當(dāng)你想要獲得每個可能的類別標(biāo)簽的概率估計(jì)值時,請使用predict_proba處理二元或多元分類問題。
- 當(dāng)你需要處理非常小或非常大的概率值時,或者當(dāng)你想要避免數(shù)字下溢或溢出問題時,請使用predict_log_proba。
- 當(dāng)你想獲得每個輸入數(shù)據(jù)點(diǎn)的分?jǐn)?shù)時,使用decision_function處理線性模型的二元分類問題。
注意:有些分類器沒有所有的預(yù)測方法或需要額外的參數(shù)來訪問函數(shù)。例如:SVC需要將概率參數(shù)設(shè)置為True,才能使用概率預(yù)測。