Python 標(biāo)準(zhǔn)庫中非常有用的裝飾器
眾所周知,Python 語言靈活、簡(jiǎn)潔,對(duì)程序員友好,但在性能上有點(diǎn)不太令人滿意,這一點(diǎn)通過一個(gè)遞歸的求斐波那契額函數(shù)就可以說明:
- def fib(n):
- if n <= 1:
- return n
- return fib(n - 1) + fib(n - 2)
在我的 MBP 上計(jì)算 fib(40) 花費(fèi)了 33 秒:
- import time
- def main():
- start = time.time()
- result = fib(40)
- end = time.time()
- cost = end - start
- print(f"{result = } {cost = :.4f}")
- if __name__ == '__main__':
- main()
但是,假如使用標(biāo)準(zhǔn)庫中的這個(gè)裝飾器,那結(jié)果完全不一樣
- from functools import lru_cache
- @lru_cache
- def fib(n):
- if n <= 1:
- return n
- return fib(n - 1) + fib(n - 2)
這次的結(jié)果是 0 秒,你沒看錯(cuò),我保留了 4 位小數(shù),后面的忽略了。
提升了多少倍?我已經(jīng)計(jì)算不出來了。
為什么 lru_cache 裝飾器這么牛逼,它到底做了什么事情?今天就來聊一聊這個(gè)最有用的裝飾器。
如果看過計(jì)算機(jī)操作系統(tǒng)的話,你對(duì) LRU 一定不會(huì)陌生,這就是著名的最近最久未使用緩存淘汰算法。
而 lru_cache 就是這個(gè)算法的具體實(shí)現(xiàn)。(這個(gè)算法可是面試經(jīng)常考的哦,有的面試官要求現(xiàn)場(chǎng)手寫代碼)
現(xiàn)在,我們來看一個(gè) lru_cache 的源代碼,其中的英文注釋,我已經(jīng)為你翻譯為中文:
- def lru_cache(maxsize=128, typed=False):
- """LRU 緩存裝飾器
- 如果 *maxsize* 是 None, 將不會(huì)淘汰緩存,緩存大小也不做限制
- 如果 *typed* 是 True, 不同類型的參數(shù)將獨(dú)立做緩存,比如 f(3.0) and f(3) 將認(rèn)為是不同的函數(shù)調(diào)用而緩存在兩個(gè)緩存節(jié)點(diǎn)上。
- 函數(shù)的參數(shù)必須可以被 hash
- 查看緩存信息使用的是命名元組 (hits, misses, maxsize, currsize)
- 查看緩存信息:user_func.cache_info(). 清理緩存信息:user_func.cache_clear().
- LRU 算法: http://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
- """
- # lru_cache 的內(nèi)部實(shí)現(xiàn)是線程安全的
- if isinstance(maxsize, int):
- # 負(fù)數(shù)轉(zhuǎn)換為 0
- if maxsize < 0:
- maxsize = 0
- elif callable(maxsize) and isinstance(typed, bool):
- #如果被裝飾的函數(shù)(user_function)直接通過 maxsize 參數(shù)傳入
- user_function, maxsize = maxsize, 128
- wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
- return update_wrapper(wrapper, user_function)
- elif maxsize is not None:
- raise TypeError(
- 'Expected first argument to be an integer, a callable, or None')
- def decorating_function(user_function):
- wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
- return update_wrapper(wrapper, user_function)
- return decorating_function
這里面有兩個(gè)參數(shù),一個(gè)是 maxsize,表示緩存的大小,當(dāng)傳入負(fù)數(shù)時(shí),自動(dòng)設(shè)置為 0,如果不傳入 maxsize,或者設(shè)置為 None,表示緩存沒有大小限制,此時(shí)沒有緩存淘汰。還有一個(gè)是 type,當(dāng) type 傳入 True 時(shí),不同的參數(shù)類型會(huì)當(dāng)作不同的 key 存到緩存當(dāng)中。
接下來,lru_cache 的核心在這個(gè)函數(shù)上 _lru_cache_wrapper,建議有感情的閱讀、背誦并默寫。我們來看下它的源代碼
- def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
- # 所有 lru cache 實(shí)例共享的常量:
- sentinel = object() # 用來表示緩存未命中的唯一對(duì)象
- make_key = _make_key # build a key from the function arguments
- PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields
- cache = {}
- hits = misses = 0
- full = False
- cache_get = cache.get # 綁定函數(shù)來獲取緩存中 key 的值
- cache_len = cache.__len__ # 綁定函數(shù)獲取緩存大小
- lock = RLock() # 因?yàn)殒湵砩系母率蔷€程不安全的
- root = [] # 循環(huán)雙向鏈表的根節(jié)點(diǎn)
- root[:] = [root, root, None, None] # 初始化根節(jié)點(diǎn)的前后指針都指向它自己
- if maxsize == 0:
- def wrapper(*args, **kwds):
- # 沒有緩存,僅更新統(tǒng)計(jì)信息
- nonlocal misses
- misses += 1
- result = user_function(*args, **kwds)
- return result
- elif maxsize is None:
- def wrapper(*args, **kwds):
- # 僅僅排序,不考慮排序和緩存大小限制
- nonlocal hits, misses
- key = make_key(args, kwds, typed)
- result = cache_get(key, sentinel)
- if result is not sentinel:
- hits += 1
- return result
- misses += 1
- result = user_function(*args, **kwds)
- cache[key] = result
- return result
- else:
- def wrapper(*args, **kwds):
- # 大小有限制,并跟蹤最近使用的緩存
- nonlocal root, hits, misses, full
- key = make_key(args, kwds, typed)
- with lock:
- link = cache_get(key)
- if link is not None:
- # 緩存命中,將命中的緩存移動(dòng)到循環(huán)雙向鏈表的頭部
- link_prev, link_next, _key, result = link
- link_prev[NEXT] = link_next
- link_next[PREV] = link_prev
- last = root[PREV]
- last[NEXT] = root[PREV] = link
- link[PREV] = last
- link[NEXT] = root
- hits += 1
- return result
- misses += 1
- result = user_function(*args, **kwds)
- with lock:
- if key in cache:
- # 走到這里說明 key 已經(jīng)放在了緩存,且鎖已經(jīng)釋放了,鏈表已經(jīng)更新了,這里什么也不需要做了,最后只需要返回計(jì)算的結(jié)果就可以了。
- pass
- elif full:
- # 如果緩存滿了, 使用最老的根節(jié)點(diǎn)來存儲(chǔ)新節(jié)點(diǎn)就可以了,鏈表上不需要?jiǎng)h除(是不是很聰明)
- oldroot = root
- oldroot[KEY] = key
- oldroot[RESULT] = result
- root = oldroot[NEXT]
- oldkey = root[KEY]
- oldresult = root[RESULT]
- root[KEY] = root[RESULT] = None
- # 最后,我們需要從緩存中清除這個(gè) key,因?yàn)樗呀?jīng)無效了。
- del cache[oldkey]
- # 新值放入緩存
- cache[key] = oldroot
- else:
- # 如果沒有滿,將新的結(jié)果放入循環(huán)雙向鏈表的頭部
- last = root[PREV]
- link = [last, root, key, result]
- last[NEXT] = root[PREV] = cache[key] = link
- # 使用 cache_len 綁定方法而不是 len() 函數(shù),后者可能會(huì)被包裝在 lru_cache 本身中
- full = (cache_len() >= maxsize)
- return result
- def cache_info():
- """報(bào)告緩存統(tǒng)計(jì)信息"""
- with lock:
- return _CacheInfo(hits, misses, maxsize, cache_len())
- def cache_clear():
- """清理緩存信息"""
- nonlocal hits, misses, full
- with lock:
- cache.clear()
- root[:] = [root, root, None, None]
- hits = misses = 0
- full = False
- wrapper.cache_info = cache_info
- wrapper.cache_clear = cache_clear
- return wrapper
如果我寫的注釋你都看明白了,那也不用看我下面的廢話了,如果還有點(diǎn)不太明白,我啰嗦幾句,也許你就明白了。
第一、所謂緩存,用的仍然是內(nèi)存,為了快速存取,用的就是一個(gè) hash 表,也就是 Python 的字典,都是在內(nèi)存里的操作。
- cache = {}
第二、如果 maxsize == 0,就相當(dāng)于沒有使用緩存,每調(diào)用一次,未命中數(shù)就 + 1,代碼邏輯是這樣的:
- def wrapper(*args, **kwds):
- nonlocal misses
- misses += 1 # 未命中數(shù)
- result = user_function(*args, **kwds)
- return result
第三、如果 maxsize == None,相當(dāng)于緩存無限制,也就不需要考慮淘汰,這個(gè)實(shí)現(xiàn)非常簡(jiǎn)單,我們直接在函數(shù)中用一個(gè)字典就可以實(shí)現(xiàn),比如說:
- cache = {}
- def fib(n):
- if n in cache:
- return cache[n]
- if n <= 1:
- return n
- result = fib(n - 1) + fib(n - 2)
- cache[n] = result
- return result
運(yùn)行時(shí)間:
理解了這一點(diǎn),在裝飾器中,這段邏輯就不難看懂:
- def wrapper(*args, **kwds):
- nonlocal hits, misses
- key = make_key(args, kwds, typed)
- result = cache_get(key, sentinel)
- if result is not sentinel:
- hits += 1
- return result
- misses += 1
- result = user_function(*args, **kwds)
- cache[key] = result
- return result
第四、真正的緩存淘汰算法。
為了實(shí)現(xiàn)緩存(鍵值對(duì))的淘汰,我們需要對(duì)緩存按時(shí)間進(jìn)行排序,這就需要用到鏈表,鏈表的頭部是最新插入的,尾部是最老插入的,當(dāng)緩存數(shù)量已經(jīng)達(dá)到最大值時(shí),我們刪除最久未使用的鏈尾節(jié)點(diǎn),為了不刪除鏈尾,我們可以使用循環(huán)鏈表,當(dāng)緩存滿了,直接更新鏈尾節(jié)點(diǎn)賦值為新節(jié)點(diǎn),并把它做為新的鏈頭就可以了。
當(dāng)緩存命中時(shí),我們需要把這個(gè)節(jié)點(diǎn)移動(dòng)到鏈表的頭部,保證鏈表的頭部是最近經(jīng)常使用的,為了移動(dòng)方便,我們需要雙向鏈表。
雙向循環(huán)鏈表在 Python 中實(shí)現(xiàn),可以簡(jiǎn)單的這么寫:
- PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields
- root = [] # root of the circular doubly linked list
- root[:] = [root, root, None, None] # initialize by pointing to self
可能有些朋友看不懂最后那行代碼:root[:] = [root, root, None, None],畫個(gè)圖你就理解了:
這些箭頭指向的都是節(jié)點(diǎn)的內(nèi)存地址,隨著節(jié)點(diǎn)的增多,就是這個(gè)樣子的:
對(duì)比這個(gè)圖,再看源代碼,就很容易看懂了。尤其是這塊的代碼邏輯,是面試常考的重點(diǎn),如果你能手寫出這樣線程安全的 LRU 緩存淘汰算法,那無疑是非常優(yōu)秀的。
其他 LRU 算法的實(shí)現(xiàn)
其他關(guān)于 LRU 算法的實(shí)現(xiàn),我自己寫了兩個(gè),可以看這里:
LRU 緩存淘汰算法-雙鏈表+hash 表[1]
LRU 緩存淘汰算法-Python 有序字典[2]
最后的話
裝飾器 lru_cache 的作用就是把函數(shù)的計(jì)算機(jī)結(jié)果保存下來,下次用的時(shí)候可以直接從 hash 表中取出,避免重復(fù)計(jì)算從而提升效率,簡(jiǎn)單點(diǎn)的,直接在函數(shù)中使用個(gè)字典就搞定了,復(fù)雜點(diǎn)的,請(qǐng)看 lru_cache 的代碼實(shí)現(xiàn)。另一方面,遞歸函數(shù)慢的一個(gè)主要原因就是重復(fù)計(jì)算。
Python 標(biāo)準(zhǔn)庫的源碼,是學(xué)習(xí)編程最有營(yíng)養(yǎng)的原料,當(dāng)你有好奇心時(shí),不妨去窺探一下源碼,相信你有定會(huì)有新的收獲。今天的分享就到這里,如果有收獲的話,請(qǐng)點(diǎn)贊、在看、轉(zhuǎn)發(fā)、關(guān)注,感謝你的支持。
參考資料
[1]
LRU 緩存淘汰算法-雙鏈表+hash 表: https://github.com/somenzz/geekbang/blob/master/algorthms/lru_use_link_table.py
[2]
LRU 緩存淘汰算法-Python 有序字典: https://github.com/somenzz/geekbang/blob/master/algorthms/lru_use_ordered_dict.py