JORA: JAX Tensor-Parallel LoRA Library for Retrieval Augmented Fine-Tuning

要約

検索ベースのタスク、特に検索拡張生成 (RAG) における大規模言語モデル (LLM) のスケーリングは、特に広範なプロンプト シーケンスを微調整する場合に、重大なメモリ制約に直面します。
現在のオープンソース ライブラリは、複数の GPU にわたるフルモデルの推論と微調整をサポートしていますが、取得されたコンテキストに必要な効率的なパラメータ分散に対応するには至っていません。
このギャップに対処するために、分散トレーニングを活用して、PEFT 互換の Llama-2 モデルの微調整のための新しいフレームワークを導入します。
私たちのフレームワークは、効率的なリソース管理のために JAX のジャストインタイム (JIT) コンパイルとテンソルシャーディングを独自に利用しており、それによってメモリ要件を削減しながら微調整を高速化できます。
この進歩により、GPU リソースが限られたシステム上でも、複雑な RAG アプリケーション向けに LLM を微調整するスケーラビリティと実現可能性が大幅に向上します。
私たちの実験では、4 つの GPU を使用した Hugging Face/DeepSpeed の実装と比較して、実行時間が 12 倍以上向上し、GPU あたりの VRAM 消費量が半分未満であることがわかりました。
私たちのライブラリはやがてオープンソース化される予定です。

要約(オリジナル)

The scaling of Large Language Models (LLMs) for retrieval-based tasks, particularly in Retrieval Augmented Generation (RAG), faces significant memory constraints, especially when fine-tuning extensive prompt sequences. Current open-source libraries support full-model inference and fine-tuning across multiple GPUs but fall short of accommodating the efficient parameter distribution required for retrieved context. Addressing this gap, we introduce a novel framework for PEFT-compatible fine-tuning of Llama-2 models, leveraging distributed training. Our framework uniquely utilizes JAX’s just-in-time (JIT) compilation and tensor-sharding for efficient resource management, thereby enabling accelerated fine-tuning with reduced memory requirements. This advancement significantly improves the scalability and feasibility of fine-tuning LLMs for complex RAG applications, even on systems with limited GPU resources. Our experiments show more than 12x improvement in runtime compared to Hugging Face/DeepSpeed implementation with four GPUs while consuming less than half the VRAM per GPU. Our library will be open-sourced in due course.

arxiv情報

著者 Anique Tahir,Lu Cheng,Huan Liu
発行日 2024-03-17 23:02:04+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.CL, cs.DC, cs.LG パーマリンク