譯者 | 朱先忠?
審校 | 孫淑娟?
簡介?
通常,在嘗試改進機器學習模型時,人們首先想到的解決方案是添加更多的訓練數據。額外的數據通常是有幫助(在某些情況下除外)的,但生成高質量的數據可能非常昂貴。通過使用現有數據獲得最佳模型性能,超參數優化可以節省我們的時間和資源。?
顧名思義,超參數優化是為機器學習模型確定最佳超參數組合以滿足優化函數(即,給定研究中的數據集,最大化模型的性能)的過程。換句話說,每個模型都會提供多個有關選項的調整“按鈕”,我們可以改變它們,直到我們模型的超參數達到最佳組合。在超參數優化過程中,我們可以更改的參數的一些示例可以是學習率、神經網絡的架構(例如,隱藏層的數量)、正則化等。?
在這篇文章中,我們將從概念上介紹三種最常見的超參數優化方法,即網格搜索、隨機搜索和貝葉斯優化,然后對它們進行逐一實現。?
我將在文章一開始提供一個高級別的比較表,以供讀者參考,然后將在本文的其余部分進一步探討、解釋和實施比較表中的每一項。?
表1:超參數優化方法比較?
1.網格搜索算法?
網格搜索可能是超參數優化的最簡單和最直觀的方法,它涉及在定義的搜索空間中徹底搜索超參數的最佳組合。在此上下文中的“搜索空間”是整個超參數以及在優化過程中考慮的此類超參數的值。讓我們通過一個示例來更好地理解網格搜索。?
假設我們有一個只有三個參數的機器學習模型,每個參數都可以取下表中提供的值:?
我們不知道這些參數的哪個組合將優化我們的模型的優化功能(即為我們的機器學習模型提供最佳輸出)。在網格搜索中,我們只需嘗試這些參數的每一個組合,測量每個參數的模型性能,然后簡單地選擇產生最佳性能的組合!在此示例中,參數1可以取3個值(即1、2或3),參數2可以取3種值(即a、b和c),參數3可以取3種值(即x、y和z)。換句話說,總共有3*3*3=27個組合。本例中的網格搜索將涉及27輪評估機器學習模型的性能,以找到性能最佳的組合。?
如您所見,這種方法非常簡單(類似于試錯任務),但也有一些局限性。讓我們一起總結一下此方法的優點和缺點。?
其中,優勢包括:?
- 易于理解和實施?
- 易于并行化?
- 適用于離散和連續空間?缺點主要有:?
- 在具有大量超參數的大型和/或復雜模型中成本高昂(因為必須嘗試和評估所有組合)?
- 無記憶——不從過去的觀察中學習?
- 如果搜索空間太大,可能無法找到最佳組合我的建議是,如果您有一個搜索空間較小的簡單模型,請使用網格搜索;否則,建議繼續往下閱讀以找到更適合更大搜索空間的解決方案。
現在,讓我們用一個真實的例子來實現網格搜索。
1.1.網格搜索算法實現?
為了實現網格搜索,我們將使用scikit-learn中的Iris數據集創建一個隨機森林分類模型。該數據集包括3種不同的鳶尾花瓣和萼片長度,將用于本次分類練習。在本文中,模型開發是次要的,因為目標是比較各種超參數優化策略的性能。我鼓勵您關注模型評估結果,以及每種超參數優化方法達到所選超參數集所需的時間。我將描述運行結果,然后為本文中使用的三種方法提供一個匯總比較表。?
包括所有超參數值的搜索空間,定義如下:?
上述搜索空間由4*5*3*3=180個超參數的總組合組成。我們將使用網格搜索來找到優化目標函數的組合,如下所示:?
上面代碼的輸出結果如下:?
這里我們可以看到使用網格搜索選擇的超參數值。其中,best_score描述了使用所選超參數集的評估結果,而elapsed_time描述了我的本地筆記本電腦執行此超參數優化策略所花費的時間。在進行下一種方法時,請記住評估結果和經過的時間,以便進行比較?,F在,讓我們進入隨機搜索的討論。?
2.隨機搜索算法?
顧名思義,隨機搜索是從定義的搜索空間中隨機采樣超參數的過程。與網格搜索不同,隨機搜索只會在預定義的迭代次數(取決于可用資源,如時間、預算、目標等)中選擇超參數值的隨機子集,并計算每個超參數的機器學習模型的性能,然后選擇最佳的超參數值。?
根據上述方法,您可以想象,與完整的網格搜索相比,隨機搜索成本更低,但仍有其自身的優勢和劣勢,如下所示:?
優勢:?
- 易于理解和實施?
- 易于并行化?
- 適用于離散和連續空間?
- 比網格搜索便宜?
- 與具有相同嘗試次數的網格搜索相比,更可能收斂到最優?缺點:?
- 無記憶——不從過去的觀察中學習?
- 考慮到隨機選擇,可能會錯過重要的超參數值?
在下一個方法中,我們將通過貝葉斯優化解決網格和隨機搜索的“無記憶”缺點。但在討論此方法之前,讓我們先來實現隨機搜索。?
2.1.隨機搜索算法實現
使用下面的代碼片段,我們將針對網格搜索實現中描述的相同問題實現隨機搜索超參數優化。?
上面代碼的輸出結果如下:?
隨機搜索結果?
與網格搜索的結果相比,這些結果非常有趣。best_score保持不變,但elapsed_time從352.0秒減少到75.5秒!真是令人印象深刻!換句話說,隨機搜索算法設法找到了一組超參數,在網格搜索所需時間的21%左右,其性能與網格搜索相同!但是,這里的效率高得多。?
接下來,讓我們繼續我們的下一種方法,稱為貝葉斯優化,它從優化過程中的每一次嘗試中學習。?
3.貝葉斯優化
貝葉斯優化是一種超參數優化方法,它使用概率模型從以前的嘗試中“學習”,并將搜索引向搜索空間中超參數的最佳組合,從而優化機器學習模型的目標函數。?
貝葉斯優化方法可以分為4個步驟,我將在下面描述。我鼓勵您通讀這些步驟,以便更好地理解流程,但使用這種方法并不需要什么前提知識。?
- 定義一個“先驗”,這是一個關于我們在某個時間點對優化目標函數最可能的超參數組合的信念的概率模型?
- 評估超參數樣本的模型?
- 使用步驟2中獲得的知識,更新步驟1中的概率模型(即我們所稱的“先驗”),了解我們認為優化目標函數的超參數的最可能組合在哪里。我們更新的信念稱為“后驗”。換句話說,在步驟2中獲得的知識幫助我們更好地了解搜索空間,并將我們從先驗帶到后驗,使后驗成為我們關于搜索空間和目標函數的“最新”知識,由步驟2提供信息?
- 重復步驟2和3,直到模型性能收斂、資源耗盡或滿足其他預定義指標?
如果您有興趣了解更多有關貝葉斯優化的詳細信息,可以查看以下帖子:?
《機器學習中的貝葉斯優化算法》,地址是:
??https://medium.com/@fmnobar/conceptual-overview-of-bayesian-optimization-for-parameter-tuning-in-machine-learning-a3b1b4b9339f。???
現在,既然我們已經了解了貝葉斯優化是如何工作的,那么讓我們來看看它的優點和缺點。?
優勢:?
- 從過去的觀察中學習,因此效率更高。換句話說,與無記憶方法相比,它有望在更少的迭代中找到一組更好的超參數?
- 在給定某些假設的情況下收斂到最優?缺點:?
- 難以并行化?
- 計算量大于網格和每次迭代的隨機搜索?
- 先驗和貝葉斯優化中使用的函數(例如,獲取函數等)的初始概率分布的選擇會顯著影響性能及其學習曲線?
在排除了細節之后,讓我們實現貝葉斯優化并查看結果。?
3.1.貝葉斯優化算法實現
與上一節類似,我們將使用下面的代碼片段為網格搜索實現中描述的相同問題實現貝葉斯超參數優化。?
上面代碼的輸出結果如下:?
貝葉斯優化結果?
另一組有趣的結果!best_score與我們通過網格和隨機搜索獲得的結果保持一致,但結果僅用了23.1秒,而隨機搜索為75.5秒,網格搜索為352.0秒!換句話說,使用貝葉斯優化所需的時間比網格搜索所需的時間大約少93%。這是一個巨大的生產力提升,在更大、更復雜的模型和搜索空間中變得更有意義。?
請注意,貝葉斯優化只使用了10次迭代就獲得了這些結果,因為它可以從以前的迭代中學習(與隨機和網格搜索相反)。?
結果比較
下表對目前所討論的三種方法的結果進行了比較?!癕ethodology(方法論)”一欄描述了所使用的超參數優化方法。隨后是使用每種方法選擇的超參數?!癇est Score”是使用特定方法獲得的分數,然后是“Elapsed Time”,表示優化策略在我的本地筆記本電腦上運行所需的時間。最后一列“獲得的效率(Gained Efficiency)”假設網格搜索為基線,然后計算與網格搜索相比,其他兩種方法中每種方法獲得的效率(使用經過的時間)。例如,由于隨機搜索耗時75.5秒,而網格搜索耗時352.0秒,因此相對于網格搜索的基線,隨機搜索的效率計算為1–75.5/352.0=78.5%。?
表2——方法性能比較表?
以上比較表中的兩個主要結論:?
- 效率:我們可以看到貝葉斯優化等學習方法如何在更短的時間內找到一組優化的超參數。?
- 參數選擇:可以有多個正確答案。例如,貝葉斯優化的選定參數與網格和隨機搜索的參數不同,盡管評估度量(即best_score)保持不變。這在更大、更復雜的環境中更為重要。?
結論
在這篇文章中,我們討論了什么是超參數優化,并介紹了用于此優化練習的三種最常見的方法。然后,我們詳細介紹了這三種方法中的每一種,并在分類練習中實現了它們。最后,我們比較了實施這三種方法的結果。我們發現,從以前的嘗試中學習的貝葉斯優化等方法可以顯著提高效率,這可能是大型復雜模型(如深度神經網絡)中的一個重要因素,其中效率可能是一個決定因素。?
譯者介紹
朱先忠,51CTO社區編輯,51CTO專家博客、講師,濰坊一所高校計算機教師,自由編程界老兵一枚。?
原文標題:??Hyperparameter Optimization — Intro and Implementation of Grid Search, Random Search and Bayesian Optimization??,作者:Farzad Mahmoodinobar?