要約
検索ベースのタスク、特に検索拡張生成 (RAG) における大規模言語モデル (LLM) のスケーリングは、特に広範なプロンプト シーケンスを微調整する場合に、重大なメモリ制約に直面します。
現在のオープンソース ライブラリは、複数の GPU にわたるフルモデルの推論と微調整をサポートしていますが、取得されたコンテキストに必要な効率的なパラメータ分散に対応するには至っていません。
このギャップに対処するために、分散トレーニングを活用して、PEFT 互換の Llama-2 モデルの微調整のための新しいフレームワークを導入します。
私たちのフレームワークは、効率的なリソース管理のために JAX のジャストインタイム (JIT) コンパイルとテンソルシャーディングを独自に利用しており、それによってメモリ要件を削減しながら微調整を高速化できます。
この進歩により、GPU リソースが限られたシステム上でも、複雑な RAG アプリケーション向けに LLM を微調整するスケーラビリティと実現可能性が大幅に向上します。
私たちの実験では、4 つの GPU を使用した Hugging Face/DeepSpeed の実装と比較して、実行時間が 12 倍以上向上し、GPU あたりの VRAM 消費量が半分未満であることがわかりました。
要約(オリジナル)
The scaling of Large Language Models (LLMs) for retrieval-based tasks, particularly in Retrieval Augmented Generation (RAG), faces significant memory constraints, especially when fine-tuning extensive prompt sequences. Current open-source libraries support full-model inference and fine-tuning across multiple GPUs but fall short of accommodating the efficient parameter distribution required for retrieved context. Addressing this gap, we introduce a novel framework for PEFT-compatible fine-tuning of Llama-2 models, leveraging distributed training. Our framework uniquely utilizes JAX’s just-in-time (JIT) compilation and tensor-sharding for efficient resource management, thereby enabling accelerated fine-tuning with reduced memory requirements. This advancement significantly improves the scalability and feasibility of fine-tuning LLMs for complex RAG applications, even on systems with limited GPU resources. Our experiments show more than 12x improvement in runtime compared to Hugging Face/DeepSpeed implementation with four GPUs while consuming less than half the VRAM per GPU.
arxiv情報
著者 | Anique Tahir,Lu Cheng,Huan Liu |
発行日 | 2024-03-19 16:19:49+00:00 |
arxivサイト | arxiv_id(pdf) |
提供元, 利用サービス
arxiv.jp, Google