Transformers meet Neural Algorithmic Reasoners

要約

Transformers は、シンプルかつ効果的なアーキテクチャにより機械学習に革命をもたらしました。
インターネットからの大量のテキスト データセットで Transformers を事前トレーニングすることで、自然言語理解 (NLU) タスクの比類のない一般化が実現しました。
ただし、このような言語モデルは、計算が正確かつ堅牢である必要があるアルゴリズム形式の推論を担当する場合には脆弱なままです。
この制限に対処するために、Transformer の言語理解とグラフ ニューラル ネットワーク (GNN) ベースのニューラル アルゴリズム推論 (NAR) の堅牢性を組み合わせた新しいアプローチを提案します。
このような NAR は、グラフ形式で指定すると、アルゴリズム タスクの汎用ソルバーとして効果的であることが証明されました。
Transformer がそのエンベディングにアクセスできるようにするために、言語モデル内のトークンが NAR からのノード エンベディングに相互参加できるようにする 2 フェーズのトレーニング手順を備えたハイブリッド アーキテクチャを提案します。
結果として得られた TransNAR モデルを CLRS-30 ベンチマークのテキストベース バージョンである CLRS-Text で評価し、配布内外の両方で、アルゴリズム推論において Transformer のみのモデルよりも大幅に向上していることを実証しました。

要約(オリジナル)

Transformers have revolutionized machine learning with their simple yet effective architecture. Pre-training Transformers on massive text datasets from the Internet has led to unmatched generalization for natural language understanding (NLU) tasks. However, such language models remain fragile when tasked with algorithmic forms of reasoning, where computations must be precise and robust. To address this limitation, we propose a novel approach that combines the Transformer’s language understanding with the robustness of graph neural network (GNN)-based neural algorithmic reasoners (NARs). Such NARs proved effective as generic solvers for algorithmic tasks, when specified in graph form. To make their embeddings accessible to a Transformer, we propose a hybrid architecture with a two-phase training procedure, allowing the tokens in the language model to cross-attend to the node embeddings from the NAR. We evaluate our resulting TransNAR model on CLRS-Text, the text-based version of the CLRS-30 benchmark, and demonstrate significant gains over Transformer-only models for algorithmic reasoning, both in and out of distribution.

arxiv情報

著者 Wilfried Bounsi,Borja Ibarz,Andrew Dudzik,Jessica B. Hamrick,Larisa Markeeva,Alex Vitvitskyi,Razvan Pascanu,Petar Veličković
発行日 2024-06-13 16:42:06+00:00
arxivサイト arxiv_id(pdf)

提供元, 利用サービス

arxiv.jp, Google

カテゴリー: cs.CL, cs.LG パーマリンク