Java為什么不能真正支持機器/深度學習?到底還欠缺了什么
如何讓團隊開始使用ML以及如何最好地將ML與我們運行的現有系統集成?
實際上沒有用Java構建的ML框架(有DL4J,但我真的不知道有誰使用它,MXNet有一個Scala API而不是Java,而且它不是用Java編寫的,Tensorflow有一個不完整的Java API),但是Java在企業中擁有巨大的使用范圍,在過去的20年中,在全球范圍內投資了數萬億美元的金融服務,交易,電子商務和電信公司 - 這個名單是無窮無盡的。對于機器學習,“第一個公民”編程語言不是Java,而是Python。
就個人而言,我喜歡用Python和Java編寫代碼,但Frank Greco提出了一個讓我思考的有趣問題:
Java還需要什么才可在ML中與Python競爭?如果Java認真對待真正支持機器學習怎么辦?
很重要么?
自1998年以來,就多個企業的變革而言,Java一直處于領先地位 - 網絡,移動,瀏覽器與原生,消息傳遞,i18n和l10n全球化支持,擴展和支持各種企業信息存儲值得一提的是,從關系數據庫到Elasticsearch。
機器學習行業并非如此。Java團隊如果進入ML只能有兩個選擇:
- 在Python中重新訓練/共同訓練。
- 使用供應商API為您的企業系統添加機器學習功能。
這兩種選擇都不是真的很好。第一個需要大量的前期時間和投資加上持續的維護成本,而第二個風險是供應商鎖定,供應商解除支持,引入第三方組件(需要支付網絡價格),這可能是一個性能關鍵系統,并且需要您可以在組織邊界之外共享數據 - 對某些人來說是不行的。
在我看來,最具破壞性的是文化消耗的可能性?- 團隊無法改變他們不理解或無法維護的代碼,Java團隊有可能在企業計算的下一波浪潮機器學習浪潮中落后 。
因此,Java編程語言和平臺擁有一流的機器學習支持是非常重要,如果沒有,Java將面臨被未來5到10年內支持ML的語言慢慢取代的風險。
為什么Python在ML中占據主導地位?
首先,讓我們考慮為什么Python是機器學習和深度學習的主要語言。
我懷疑這一切都始于一個功能 - 列表的切片slicing支持。這種支持是可擴展的:任何實現__getitem__和__setitem__方法的Python類都可以使用這種語法進行切片。下面的代碼段顯示了這個Python功能的強大和自然性。
a = [1, 2, 3, 4, 5, 6, 7, 8]
print(a[1:4])
#returns [2, 3, 4] - -挑選出中間元素的切片
print(a[1:-1])
#returns [2, 3, 4, 5, 6, 7] - 跳過第0和最后一個元素
print(a[4:])
#returns [5, 6, 7, 8] - 終點默認
print(a[:4])
#returns [1, 2, 3, 4] -開始點被默認
當然,還有更多。與舊的Java代碼相比,Python代碼更簡潔,更簡潔。支持未經檢查的異常,開發人員可以輕松地編寫一次性Python腳本來嘗試填充,而不會陷入“一切都是一個類”的Java思維模式中。使用Python很容易。
但是現在我認為是主要因素 - 盡管Python社區在維持2.7和3之間的凝聚力方面做了一頓狗晚餐,但他們在構建設計良好,快速的數字計算庫(NumPy)方面做得更好 。Numpy是圍繞ndarray構建的 - N維數組對象。直接來自文檔:“ NumPy的主要對象是同構多維數組。它是一個元素表(通常是數字),所有相同的類型,由正整數元組索引 “。
NumPy中的所有內容都是將數據放入ndarray然后對其執行操作。NumPy支持多種類型的索引,廣播,矢量化以提高速度,并且通常允許開發人員輕松創建和操作大型數字數組。
下一個片段顯示了ndarray 索引和正在進行的廣播,這些是ML / DL中的核心操作。
import numpy as np
#Simple broadcast example
a = np.array([1.0, 2.0, 3.0])
b = 2.0
c = a * b
print(c)
#returns [ 2. 4. 6.] - the scalar b is automatically promoted / broadcast and applied to the vector a to create c
#return返回[2. 4. 6.] - 標量b被自動提升/廣播并應用于向量a以創建c
#2-d (matrix with rank 2) indexing in NumPy - this extends to Tensors - i.e. rank > 2
y = np.arange(35).reshape(5,7)
print(y)
# array([[ 0, 1, 2, 3, 4, 5, 6],
# [ 7, 8, 9, 10, 11, 12, 13],
# [14, 15, 16, 17, 18, 19, 20],
# [21, 22, 23, 24, 25, 26, 27],
# [28, 29, 30, 31, 32, 33, 34]])
print(y[0,0])
# 單個單元格訪問 - notation is row-major, returns 0
print(y[4,])
# returns all of row 4: array([28, 29, 30, 31, 32, 33, 34])
print(y[:,2])
# returns all of column 2: array([ 2, 9, 16, 23, 30])
處理大型多維數字數組是機器學習編碼的核心,尤其是深度學習。深度神經網絡是節點格和邊格的數字模型。在訓練網絡或對其進行推理時的運行時操作需要快速矩陣乘法。
NumPy已經促成并啟用了更多 - ?scipy,pandas和許多其他依賴于NumPy的庫。領先的深度學習庫(Tensorflow來自谷歌,PyTorch來自Facebook)都投入巨資在Python。Tensorflow還有其他用于Go,Java和JavaScript的API,但它們不完整且被視為不穩定。PyTorch最初是用Lua編寫的,當它們從2017年相當小的語言轉移到主要的Python ML生態系統時,它的受歡迎程度大幅上升。
Python的缺點
Python不是一種完美的語言 - 特別是最流行的Python運行時 - CPython - 具有全局解釋器鎖(GIL),因此性能縮放并不簡單。此外,像PyTorch和Tensorflow這樣的Python DL框架仍然將核心方法交給不透明的實現。
例如,NVidia 的cuDNN庫對PyTorch中[url=
https://pytorch.org/docs/stable/nn.html#rnn]RNN / LSTM實現[/url]的范圍產生了深遠的影響。RNN和LSTM是一種非常重要的DL技術,特別適用于商業應用,因為它們專門用于對順序,可變長度序列進行分類和預測 - 例如網絡點擊流,文本片段,用戶事件等。
為了公平對待Python,這種不透明度/限制幾乎適用于任何未用C或C ++編寫的ML / DL框架。為什么?因為為了獲得核心的最大性能,像矩陣乘法這樣的高頻操作,開發人員盡可能“接近底層冶金工藝”。
Java需要做些什么才能參與競爭?
我建議Java平臺有三個主要的補充,如果存在的話,會促使Java中一個健康且蓬勃發展的機器學習生態系統的萌芽:
1.在核心語言中添加本機索引/切片支持,以與Python的易用性和表現力相媲美,可能以現有的有序集合List
接口為中心。這種支持還需要承認重載以支持#2點。
2.構建Tensor實現 - 可能在java.math包中,但也可以橋接到Collections API。這組類和接口將作為ndarray的等價物,并提供額外的索引支持 - 特別是三種類型的NumPy索引:字段訪問,基本切片和編碼ML所必需的高級索引。
3.支持廣播 - 任意(但兼容)維度的標量和張量。
如果在核心Java語言和運行時中存在這三件事,它將開辟構建“ NumJava ” 的道路,相當于NumPy。巴拿馬項目還可以用于提供對CPU,GPU,TPU等運行的快速張量操作的矢量化低級訪問,以幫助Java ML成為最快的。
我并不是說這些補充是微不足道的 - 遠非如此,但Java平臺的潛在優勢是巨大的。
下面的代碼片段展示了我們的NumPy廣播和索引示例如何在NumJava中使用Tensor類,核心語言支持切片語法,并尊重當前對運算符重載的限制。
//Java廣播的張量
//使用Java 10中的var語法進行簡潔性
// Java不支持運算符重載,所以我們不能做“a * b”
//我們應該將其添加到需求列表中嗎?
var a = new Tensor([1.0, 2.0, 3.0]);
var b = 2.0;
var c = a.mult(b);
/**
* And a snippet showing how the Java Tensor class could look.
*顯示Java Tensor類的外觀的片段。
*/
import static java.math.Numeric.arange;
//arange returns a tensor instance and reshape is defined on tensor
var y = arange(35).reshape(5,7);
System.out.println(y);
// tensor([[ 0, 1, 2, 3, 4, 5, 6],
// [ 7, 8, 9, 10, 11, 12, 13],
// [14, 15, 16, 17, 18, 19, 20],
// [21, 22, 23, 24, 25, 26, 27],
// [28, 29, 30, 31, 32, 33, 34]])
System.out.println(y[0,0]);
// single cell access - notation is row-major, returns 0
System.out.println(y[4,]);
// returns all of row 4 (5th row starting from 0 idx): tensor([28, 29, 30, 31, 32, 33, 34])
System.out.println(y[:,2]);
// returns all of column 2 (3rd col starting from 0 idx): tensor([ 2, 9, 16, 23, 30])
總結
從本文中概述的實用起點開始,我們可以擁有用Java編寫并在JRE上運行的盡可能多的機器/深度學習框架,因為我們有Web,持久性或XML解析器 - 想象一下!我們可以設想Java框架支持卷積神經網絡(CNN)用于前沿計算機視覺,像LSTM這樣的循環神經網絡實現對于順序數據集(對業務至關重要),具有尖端的ML功能,如自動差異化等。然后,這些框架將為下一代企業級系統提供動力并為其提供動力 - 所有這些系統都使用相同的工具 - IDE,測試框架和持續集成。