在一張 24 GB 的消費級顯卡上用 RLHF 微調(diào) 20B LLMs(24g顯卡是干嘛的)
在一張 24 GB 的消費級顯卡上用 RLHF 微調(diào) 20B LLMs(24g顯卡是干嘛的)
我們很高興正式發(fā)布 trl 與 peft 的集成,使任何人都可以更輕松地使用強化學習進行大型語言模型 (LLM) 微調(diào)!在這篇文章中,我們解釋了為什么這是現(xiàn)有微調(diào)方法的有競爭力的替代方案。
請注意, peft 是一種通用工具,可以應用于許多 ML 用例,但它對 RLHF 特別有趣,因為這種方法特別需要內(nèi)存!
如果你想直接深入研究代碼,請直接在 TRL 的文檔頁面 直接查看示例腳本。
介紹
LLMs & RLHF
LLM 結(jié)合 RLHF (人類反饋強化學習) 似乎是構(gòu)建非常強大的 AI 系統(tǒng) (例如 ChatGPT) 的下一個首選方法。
使用 RLHF 訓練語言模型通常包括以下三個步驟:
- 在特定領(lǐng)域或指令和人類示范語料庫上微調(diào)預訓練的 LLM;
- 收集人類標注的數(shù)據(jù)集,訓練一個獎勵模型;
- 使用 RL (例如 PPO),用此數(shù)據(jù)集和獎勵模型進一步微調(diào)步驟 1 中的 LLM。
ChatGPT 的訓練協(xié)議概述,從數(shù)據(jù)收集到 RL 部分。 資料來源: OpenAI 的 ChatGPT 博文 |
基礎(chǔ) LLM 的選擇在這里是至關(guān)重要的。在撰寫本文時,可以“開箱即用”地用于許多任務的“最佳”開源 LLM 是指令微調(diào) LLMs。著名的模型有: BLOOMZ Flan-T5、Flan-UL2 和 OPT-IML。這些模型的缺點是它們的尺寸。要獲得一個像樣的模型,你至少需要玩 10B 級別的模型,在全精度情況下這將需要高達 40GB GPU 內(nèi)存,只是為了將模型裝在單個 GPU 設(shè)備上而不進行任何訓練!
什么是 TRL?
trl 庫的目的是使 RL 的步驟更容易和靈活,讓每個人可以在他們自己的數(shù)據(jù)集和訓練設(shè)置上用 RL 微調(diào) LM。在許多其他應用程序中,你可以使用此算法微調(diào)模型以生成 正面電影評論、進行 受控生成 或 降低模型的毒性。
使用 trl 你可以在分布式管理器或者單個設(shè)備上運行最受歡迎的深度強化學習算法之一: PPO。我們利用 Hugging Face 生態(tài)系統(tǒng)中的 accelerate 來實現(xiàn)這一點,這樣任何用戶都可以將實驗擴大到一個有趣的規(guī)模。
使用 RL 微調(diào)語言模型大致遵循下面詳述的協(xié)議。這需要有 2 個原始模型的副本; 為避免活躍模型與其原始行為/分布偏離太多,你需要在每個優(yōu)化步驟中計算參考模型的 logits 。這對優(yōu)化過程增加了硬約束,因為你始終需要每個 GPU 設(shè)備至少有兩個模型副本。如果模型的尺寸變大,在單個 GPU 上安裝設(shè)置會變得越來越棘手。
TRL 中 PPO 訓練設(shè)置概述。 |
在 trl 中,你還可以在參考模型和活躍模型之間使用共享層以避免整個副本。 模型解毒示例中展示了此功能的具體示例。
大規(guī)模訓練
大規(guī)模訓練是具有挑戰(zhàn)性的。第一個挑戰(zhàn)是在可用的 GPU 設(shè)備上擬合模型,及其優(yōu)化器狀態(tài)。 單個參數(shù)占用的 GPU 內(nèi)存量取決于其“精度”(或更具體地說是 dtype)。 最常見的 dtype 是 float32 (32 位) 、 float16 和 bfloat16 (16 位)。 最近,“奇異的”精度支持開箱即用的訓練和推理 (具有特定條件和約束),例如 int8 (8 位)。 簡而言之,要在 GPU 設(shè)備上加載一個模型,每十億個參數(shù)在 float32 精度上需要 4GB,在 float16 上需要 2GB,在 int8 上需要 1GB。 如果你想了解關(guān)于這個話題的更多信息,請查看這篇研究深入的 文章。
如果您使用 AdamW 優(yōu)化器,每個參數(shù)需要 8 個字節(jié) (例如,如果您的模型有 1B 個參數(shù),則模型的完整 AdamW 優(yōu)化器將需要 8GB GPU 內(nèi)存 來源)。
許多技術(shù)已經(jīng)被采用以應對大規(guī)模訓練上的挑戰(zhàn)。最熟悉的范式是管道并行、張量并行和數(shù)據(jù)并行。
圖片來自 這篇博文 |
通過數(shù)據(jù)并行性,同一模型并行托管在多臺機器上,并且每個實例都被提供不同的數(shù)據(jù)批次。 這是最直接的并行策略,本質(zhì)上是復制單 GPU 的情況,并且已經(jīng)被 trl 支持。 使用管道并行和張量并行,模型本身分布在機器上: 在管道并行中,模型按層拆分,而張量并行則跨 GPU 拆分張量操作 (例如矩陣乘法)。使用這些模型并行策略,你需要將模型權(quán)重分片到許多設(shè)備上,這需要你定義跨進程的激活和梯度的通信協(xié)議。 這實現(xiàn)起來并不簡單,可能需要采用一些框架,例如 Megatron-DeepSpeed 或 Nemo。其他對擴展訓練至關(guān)重要的工具也需要被強調(diào),例如自適應激活檢查點和融合內(nèi)核。 可以在 擴展閱讀 找到有關(guān)并行范式的進一步閱讀。
因此,我們問自己下面一個問題: 僅用數(shù)據(jù)并行我們可以走多遠?我們能否使用現(xiàn)有的工具在單個設(shè)備中適應超大型訓練過程 (包括活躍模型、參考模型和優(yōu)化器狀態(tài))? 答案似乎是肯定的。 主要因素是: 適配器和 8 位矩陣乘法! 讓我們在以下部分中介紹這些主題:
8 位矩陣乘法
高效的 8 位矩陣乘法是論文 LLM.int8() 中首次引入的一種方法,旨在解決量化大規(guī)模模型時的性能下降問題。 所提出的方法將在線性層中應用的矩陣乘法分解為兩個階段: 在 float16 中將被執(zhí)行的異常值隱藏狀態(tài)部分和在 int8 中被執(zhí)行的“非異常值”部分。
高效的 8 位矩陣乘法是論文 LLM.int8() 中首次引入的一種方法,旨在解決量化大規(guī)模模型時的性能下降問題。 所提出的方法將在線性層中應用的矩陣乘法分解為兩個階段: 在 float16 中被執(zhí)行的異常值隱藏狀態(tài)部分和在 int8 中被執(zhí)行的“非異常值”部分。 |
簡而言之,如果使用 8 位矩陣乘法,則可以將全精度模型的大小減小到 4 分之一 (因此,對于半精度模型,可以減小 2 分之一)。
低秩適配和 PEFT
在 2021 年,一篇叫 LoRA: Low-Rank Adaption of Large Language Models 的論文表明,可以通過凍結(jié)預訓練權(quán)重,并創(chuàng)建查詢和值層的注意力矩陣的低秩版本來對大型語言模型進行微調(diào)。這些低秩矩陣的參數(shù)遠少于原始模型,因此可以使用更少的 GPU 內(nèi)存進行微調(diào)。 作者證明,低階適配器的微調(diào)取得了與微調(diào)完整預訓練模型相當?shù)慕Y(jié)果。
原始 (凍結(jié)的) 預訓練權(quán)重 (左) 的輸出激活由一個由權(quán)重矩陣 A 和 B 組成的低秩適配器 (右) 增強。 |
這種技術(shù)允許使用一小部分內(nèi)存來微調(diào) LLM。 然而,也有一些缺點。由于適配器層中的額外矩陣乘法,前向和反向傳遞的速度大約是原來的兩倍。
什么是 PEFT?
Parameter-Efficient Fine-Tuning (PEFT) 是一個 Hugging Face 的庫,它被創(chuàng)造出來以支持在 LLM 上創(chuàng)建和微調(diào)適配器層。 peft 與 Accelerate 無縫集成,用于利用了 DeepSpeed 和 Big Model Inference 的大規(guī)模模型。
此庫支持很多先進的模型,并且有大量的例子,包括:
- 因果語言建模
- 條件生成
- 圖像分類
- 8 位 int8 訓練
- Dreambooth 模型的低秩適配
- 語義分割
- 序列分類
- 詞符分類
該庫仍在廣泛和積極的開發(fā)中,許多即將推出的功能將在未來幾個月內(nèi)公布。
使用低質(zhì)適配器微調(diào) 20B 參數(shù)量的模型
現(xiàn)在先決條件已經(jīng)解決,讓我們一步步過一遍整個管道,并用圖說明如何在單個 24GB GPU 上使用上述工具使用 RL 微調(diào) 20B 參數(shù)量的 LLM!
第 1 步: 在 8 位精度下加載你的活躍模型
與全精度模型相比,以 8 位精度加載模型最多可節(jié)省 4 倍的內(nèi)存 |
使用 transformers 減少 LLM 內(nèi)存的“免費午餐”是使用 LLM.int8 中描述的方法,以 8 位精度加載模型。 這可以通過在調(diào)用 from_pretrained 方法時簡單地添加標志 load_in_8bit=True 來執(zhí)行 (你可以在 文檔中 閱讀更多相關(guān)信息)。
如前一節(jié)所述,計算加載模型所需的 GPU 內(nèi)存量的“技巧”是根據(jù)“十億個參數(shù)量”進行思考。 由于一個字節(jié)需要 8 位,因此全精度模型 (32 位 = 4 字節(jié)) 每十億個參數(shù)需要 4GB,半精度模型每十億個參數(shù)需要 2GB,int8 模型每十億個參數(shù)需要 1GB。
所以首先,我們只加載 8 位的活躍模型。 讓我們看看第二步需要做什么!
第 2 步: 使用peft添加額外可訓練的適配器
您可以輕松地在凍結(jié)的 8 位模型上添加適配器,從而通過訓練一小部分參數(shù)來減少優(yōu)化器狀態(tài)的內(nèi)存需求 |
第二步是在模型中加載適配器并使這些適配器可訓練。 這可以大幅減少活躍模型所需的可訓練權(quán)重的數(shù)量。 此步驟利用 peft 庫,只需幾行代碼即可執(zhí)行。 請注意,一旦適配器經(jīng)過訓練,您就可以輕松地將它們推送到 Hub 以供以后使用。
第 3 步: 使用同樣的模型得到參考和活躍 logits
你可以方便地使用 peft 關(guān)閉和使能適配器。 |
由于適配器可以停用,我們可以使用相同的模型來獲取 PPO 的參考和活躍的 logits 值,而無需創(chuàng)建兩個相同的模型副本! 這利用了 peft 庫中的一個功能,即 disable_adapters 上下文管理器。
訓練腳本概述
我們現(xiàn)在將描述如何使用 transformers 、 peft 和 trl 訓練 20B 參數(shù)量的 gpt-neox 模型。 這個例子的最終目標是微調(diào) LLM 以在內(nèi)存受限的設(shè)置中生成積極的電影評論。類似的步驟可以應用于其他任務,例如對話模型。
整體來看有三個關(guān)鍵步驟和訓練腳本:
- 腳本 1 – 在凍結(jié)的 8 位模型上微調(diào)低秩適配器,以便在 imdb 數(shù)據(jù)集上生成文本。
- 腳本 2 – 將適配器層合并到基礎(chǔ)模型的權(quán)重中并將它們存儲在 Hub 上。
- 腳本 3 – 對低等級適配器進行情感微調(diào)以創(chuàng)建正面評價。
我們在 24GB NVIDIA 4090 GPU 上測試了這些步驟。雖然可以在 24 GB GPU 上執(zhí)行整個訓練過程,但在 研究集群上的單個 A100 上無法進行完整的訓練過程。
訓練過程的第一步是對預訓練模型進行微調(diào)。 通常這需要幾個高端的 80GB A100 GPU,因此我們選擇訓練低階適配器。 我們將其視為一種因果語言建模設(shè)置,并針從 imdb 數(shù)據(jù)集中訓練了一個 epoch 的示例,該數(shù)據(jù)集具有電影評論和指明積極還是消極情緒的標簽。
在 imdb 數(shù)據(jù)集上訓練 gpt-neox-20b 模型一個 epoch 期間的訓練損失 |
為了利用已經(jīng)適配了的模型并使用 RL 執(zhí)行進一步微調(diào),我們首先需要組合自適應權(quán)重,這是通過加載預訓練模型和 16 位浮點精度的適配器并且加和權(quán)重矩陣 (應用適當?shù)目s放比例) 來實現(xiàn)的。
最后,我們可以在凍結(jié)的、用 imdb 微調(diào)過的模型之上微調(diào)另一個低秩適配器。我們使用一個 imdb 情感分類器 來為 RL 算法提供獎勵。
RL 微調(diào) peft 適配過的 20B 參數(shù)量的模型以生成積極影評時的獎勵均值。 |
如果您想查看更多圖表和文本生成,可在 鏈接處 獲取此實驗的完整權(quán)重和偏差報告。
結(jié)論
我們在 trl 中實現(xiàn)了一項新功能,允許用戶利用 peft 和 bitsandbytes 庫以合理的成本使用 RLHF 微調(diào)大型語言模型。 我們證明了可以在 24GB 消費級 GPU 上微調(diào) gpt-neo-x (以 bfloat16 精度需要 40GB!),我們期望社區(qū)將廣泛使用此集成來微調(diào)利用了 RLHF 的大型模型,并分享出色的工件。
我們已經(jīng)為接下來的步驟確定了一些有趣的方向,以挑戰(zhàn)這種集成的極限:
- 這將如何在多 GPU 設(shè)置中擴展? 我們將主要探索這種集成將如何根據(jù) GPU 的數(shù)量進行擴展,是否可以開箱即用地應用數(shù)據(jù)并行,或者是否需要在任何相關(guān)庫上采用一些新功能。
- 我們可以利用哪些工具來提高訓練速度? 我們觀察到這種集成的主要缺點是整體訓練速度。 在未來,我們將持續(xù)探索使訓練更快的可能方向。
參考
- 并行范式: https://hf.co/docs/transformers/v4.17.0/en/parallelism
- transformers 中的 8 位集成: https://hf.co/blog/hf-bitsandbytes-integration
- LLM.int8 論文: https://arxiv.org/abs/2208.07339
- 梯度檢查點解釋: https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-extended-features-pytorch-activation-checkpointing.html