你真的對你構(gòu)建的ML模型足夠了解么?!
點擊以上,盡在AI中國
我們真的知道我們建立的ML模型中發(fā)生了什么嗎?我們來探索一下。
在我之前的博客中,我們看到了對XGBoost、LightGBM和Catboost的對比研究。通過這種分析,我們得出結(jié)論,catboost在速度和準(zhǔn)確性上優(yōu)于其他兩種。在這一部分中,我們將深入討論catboost,并探索catboost為高效建模和理解超參數(shù)提供的新功能。
對于新讀者,catboost是Yandex團隊在2017年開發(fā)的開源梯度增強算法。它是一種機器學(xué)習(xí)算法,允許用戶快速處理大數(shù)據(jù)集的分類特征,不同于XGBoost和LightGBM。Catboost可用于解決回歸、分類和排序問題。
作為數(shù)據(jù)科學(xué)家,我們可以很容易地訓(xùn)練模型并做出預(yù)測,但我們往往無法理解這些奇怪的算法中發(fā)生了什么。這也是為什么我們看到離線評測和最終生產(chǎn)的模型性能存在巨大差異的原因之一?,F(xiàn)在我們應(yīng)該停止將ML視為“黑箱”,在提高模型準(zhǔn)確性的同時,關(guān)注模型解釋。這也有助于我們識別數(shù)據(jù)偏差。在本節(jié)中,我們將了解catboost如何通過以下函數(shù)幫助我們分析模型并提高可見性:
特征重要性
你為什么想知道它?刪除不必要的特征以簡化模型和縮短訓(xùn)練/預(yù)測時間,獲得對您的目標(biāo)值最有影響的函數(shù)并操縱它們以獲得商業(yè)利益(例如,醫(yī)療醫(yī)療保健提供商希望確定哪些因素會導(dǎo)致每個患者患某些疾病的風(fēng)險,以便他們可以通過靶向藥物直接解決這些風(fēng)險因素)。
除了選擇特征重要性的類型,我們還應(yīng)該知道我們想要使用哪些數(shù)據(jù)來發(fā)現(xiàn)特征重要性——訓(xùn)練、測試或完成數(shù)據(jù)集。選擇一個對另一個有利有弊,但是最后你需要決定你是想知道模型在多大程度上依賴于每個特征來預(yù)測(使用訓(xùn)練數(shù)據(jù))還是該特征對性能的貢獻程度以及模型對看不見的數(shù)據(jù)的影響(使用測試數(shù)據(jù))。正如我們將在后面看到的,只有一些方法可用于發(fā)現(xiàn)不用于訓(xùn)練模型的數(shù)據(jù)的特征重要性。
如果你關(guān)心第二個,并假設(shè)你有所有的時間和資源,找出特征重要性的最原始和可靠的方法是訓(xùn)練多個模型,一次留下一個特征,并在測試集上比較性能。如果性能從基線(我們使用所有函數(shù)時的性能)變化很大,說明這個特性非常重要。但是,因為我們生活在一個需要優(yōu)化精度和計算時間的實際環(huán)境中,所以這種方法是不必要的。以下是一些智能方法,其中catboost可以幫助您找到最適合您的模型的函數(shù):
CB . get _ feature _ importance(type = " _ _ _ _ _ _ ")
“類型”可能值:
-預(yù)測值更改
-損失功能變化
-功能重要性
預(yù)測值更改非排名指標(biāo)和損失函數(shù)更改排名指標(biāo)
-形狀值
計算每個對象的SHAP值
-互動
計算每個特征之間的成對得分
預(yù)測值變化
對于每個特征,PredictionValuesChange顯示了特征值變化時預(yù)測的平均變化程度。重要性值越大,平均值越大。如果這個特性改變了,預(yù)測值也會改變。
優(yōu)點:計算成本低,因為不需要多次訓(xùn)練或測試,不會存儲任何額外的信息。您將得到標(biāo)準(zhǔn)化的值作為輸出(所有的重要性加起來為100)。
缺點:可能會給排名目標(biāo)帶來誤導(dǎo)性的結(jié)果,可能會把groupwise的特性放在最前面,即使它們對最終的損失值有一點影響。
損失函數(shù)損失函數(shù)變化
為了獲得該特征的重要性,catboost簡單地采用在正常條件下(當(dāng)我們包括該特征時)使用該模型獲得的度量(損失函數(shù))和沒有該特征的模型之間的差異。差異越大,特征越重要。在catboost文檔中,沒有明確提到如何找到?jīng)]有特征的模型。
優(yōu)缺點:這對于大部分題型都很有效,不同于預(yù)測值的變化。在這種情況下,你可以得到排名問題的誤導(dǎo)性結(jié)果,同時,它的計算量很大。
形狀值
https://github.com/slundberg/shap
SHAP值將預(yù)測值分解為每個元素的貢獻。它測量特征對單個預(yù)測值的影響,并比較基線預(yù)測(訓(xùn)練數(shù)據(jù)集的目標(biāo)值的平均值)。
shap值的兩個主要使用案例:
1.特征的對象級貢獻
shap _ values = model . get _ feature _ importance(Pool(X _ test,label=y_test,cat _ features = category _ features _ indexes),
type="ShapValues ")
expected_value = shap_values[0,-1]
形狀值=形狀值[:,:-1]
shap.initjs()
shap.force_plot(expected_value,shap_values[3,:,X_test.iloc[3,:))
https://github.com/slundberg/shap
2.整個數(shù)據(jù)集的概要(整體特征重要性)
形狀摘要圖(形狀值,X測試)
雖然我們可以通過shap得到準(zhǔn)確的特征重要度,但是它們在計算上比catboost內(nèi)置的特征重要度更昂貴。
捕獲賞金
基于相同概念但不同實現(xiàn)的另一個特征重要性是基于排列的特征重要性。Catboost不使用,純模型無關(guān),計算簡單。
我們?nèi)绾芜x擇?
雖然PredictionValuesChange和losfunctionchange都可以用于所有類型的指標(biāo),但建議使用losfunctionchange對指標(biāo)進行排序。除了PredictionValuesChange,其他所有方法都可以使用測試數(shù)據(jù),并使用根據(jù)訓(xùn)練數(shù)據(jù)訓(xùn)練的模型來查找特征重要性。
為了更好地理解這些差異,下面是我們討論的所有方法的結(jié)果:
catboost功能的結(jié)果。從經(jīng)典的“成人”人口普查數(shù)據(jù)集中預(yù)測人們是否會上報5萬美元以上的收入(使用對數(shù)損失)。
從上圖可以看出,大多數(shù)方法都符合頂層特性??雌饋鞮ossFunctionChange最接近shap(更可靠)。但是直接比較這些方法是不公平的,因為預(yù)測值的變化是基于訓(xùn)練數(shù)據(jù)的,而其他方法都是基于測試數(shù)據(jù)的。
我們還應(yīng)該看到運行所有這些所需的時間:
互相地
有了這個參數(shù),就可以求出一對元素的強弱(兩個元素的重要性)。
在輸出中,您將獲得每對特征的列表。該列表將有三個值,第一個值是該對中第一個元素的索引,第二個值是該對中第二個元素的索引,第三個值是該對的要素重要性分?jǐn)?shù)。請查看嵌入式筆記本了解實施詳情。
值得注意的是,單個特征重要性中的前兩個特征不一定是最強的一對。
筆記本
筆記本電腦中使用的數(shù)據(jù)集
對象重要性
你為什么想知道它?從訓(xùn)練數(shù)據(jù)中刪除最無用的訓(xùn)練對象。根據(jù)哪些新對象預(yù)計最“有用”,優(yōu)先給一批新對象做標(biāo)記,類似于主動學(xué)習(xí)。
通過這個函數(shù),可以計算出每個對象對測試數(shù)據(jù)優(yōu)化指標(biāo)的影響。正值反映優(yōu)化指標(biāo)的增加,負值反映優(yōu)化指標(biāo)的減少。該方法是本文中描述的方法的實現(xiàn)。這些算法的細節(jié)超出了本文的范圍。
關(guān)于對象重要性的Catboost教程
cb.get_object_importance中有三種類型的update _ methods:single point:最快最不精確的方法TopKLeaves:指定葉子的個數(shù)。該值越高,計算越精確,速度越慢。AllPoints:最慢最準(zhǔn)確的方法。
例如,以下值將方法設(shè)置為TopKLeaves,并將葉數(shù)限制為3:
TopKLeaves:top= 3
模型分析圖
Catboost最近在其最新更新中引入了這一功能。有了這個函數(shù),我們將能夠可視化算法如何劃分每個特征的數(shù)據(jù),并查看特定于特征的統(tǒng)計數(shù)據(jù)。更具體地說,我們將能夠看到:每個bin的平均目標(biāo)值(bin用于連續(xù)函數(shù))或每個類別的平均預(yù)測值(目前僅支持OHE函數(shù));通過不同的特征值預(yù)測每個箱中的對象數(shù)量;對于每一個物體,特征值都是變化的,所以會落在某個區(qū)域。然后,模型根據(jù)新的特征值對目標(biāo)進行預(yù)測,得到預(yù)測在區(qū)域內(nèi)的平均值(由紅點給出)。
這個圖表將為我們提供信息,例如我們?nèi)绾纹骄指?我們不希望所有對象都進入一個區(qū)域),我們的預(yù)測是否接近目標(biāo)(藍色和橙色線),紅線將告訴我們我們的預(yù)測有多敏感。
數(shù)字特征分析
單一熱編碼特征分析