>  기사  >  기술 주변기기  >  Mamba 저자의 새로운 작업: Llama3을 하이브리드 선형 RNN으로 증류

Mamba 저자의 새로운 작업: Llama3을 하이브리드 선형 RNN으로 증류

王林
王林원래의
2024-09-02 13:41:30870검색

딥 러닝 분야에서 Transformer가 큰 성공을 거둘 수 있었던 열쇠는 어텐션 메커니즘입니다. 어텐션 메커니즘을 통해 Transformer 기반 모델은 입력 시퀀스와 관련된 부분에 집중하여 더 나은 상황 이해를 달성할 수 있습니다. 그러나 어텐션 메커니즘의 단점은 계산 오버헤드가 높고 입력 크기에 따라 2차적으로 증가하여 Transformer가 매우 긴 텍스트를 처리하기 어렵게 만든다는 것입니다.

얼마 전 Mamba의 출현으로 이러한 상황이 깨졌고, 이는 컨텍스트 길이가 증가함에 따라 선형 확장을 달성할 수 있습니다. Mamba가 출시됨에 따라 이러한 상태 공간 모델(SSM)은 이미 중소 규모에서 Transformer와 일치하거나 능가할 수 있을 뿐만 아니라 시퀀스 길이에 대한 선형 확장성을 유지하여 Mamba에 유리한 배포 특성을 제공합니다.

간단히 말하면 Mamba는 먼저 입력에 따라 SSM을 다시 매개변수화할 수 있는 간단하지만 효과적인 선택 메커니즘을 도입하여 모델이 관련 없는 정보 및 관련 데이터를 필터링하면서 필요한 정보를 무기한 유지할 수 있도록 합니다.

최근 "The Mamba in the Llama: Distilling and Acceleating Hybrid Models"라는 제목의 논문에서는 Attention 레이어의 가중치를 재사용함으로써 대형 트랜스포머를 대형 하이브리드 선형 RNN으로 증류할 수 있음을 입증했습니다. 대부분의 빌드 품질을 유지하면서.

주의 계층의 4분의 1을 포함하는 결과 하이브리드 모델은 채팅 벤치마크에서 원래 Transformer와 비슷한 성능을 달성하고 채팅 벤치마크 및 일반 벤치마크의 데이터 사용 성능을 능가합니다. 오픈 소스 하이브리드 Mamba 모델. 1조 개의 토큰으로 처음부터 훈련되었습니다. 또한 이 연구에서는 Mamba 및 하이브리드 모델에 대한 추론 속도를 높이는 하드웨어 인식 추측 디코딩 알고리즘을 제안합니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

논문 주소: https://arxiv.org/pdf/2408.15237

이 연구에서 가장 성능이 좋은 모델은 Llama3-8B-Instruct Distilled에서 나온 것입니다. , GPT-4에 비해 AlpacaEval 2에서 29.61의 길이 제어 승률을 달성했으며 MT-Bench에서 7.35의 승률을 달성하여 최고의 명령 조정 선형 RNN 모델을 능가했습니다.

방법

KD(Knowledge Distillation)는 대형 모델(교사 모델)에서 소형 모델(학생 모델) 모델로 지식을 전달하는 데 사용되는 모델 압축 기술입니다. ), 이는 교사 네트워크의 행동을 모방하도록 학생 네트워크를 훈련시키는 것을 목표로 합니다. 이 연구의 목표는 Transformer의 성능이 원래 언어 모델과 비슷하도록 증류하는 것입니다.

본 연구에서는 점진적 증류, 감독된 미세 조정 및 방향성 선호 최적화를 결합한 다단계 증류 방법을 제안합니다. 일반 증류와 비교하여 이 방법은 더 나은 혼란과 다운스트림 평가 결과를 얻을 수 있습니다.

본 연구에서는 Transformer의 지식 대부분이 원본 모델에서 전달된 MLP 계층에 유지된다고 가정하고 증류된 LLM의 미세 조정 및 정렬 단계에 중점을 둡니다. 이 단계에서는 MLP 계층이 동결된 상태로 유지되고 Mamba 계층이 훈련됩니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

이 연구는 선형 RNN과 주의 메커니즘 사이에 자연스러운 연결이 있다고 믿습니다. 어텐션 공식은 소프트맥스를 제거하여 선형화할 수 있습니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

그러나 어텐션을 선형화하면 모델 성능이 저하됩니다. 효율적인 증류 선형 RNN을 설계하기 위해 본 연구에서는 선형 RNN의 용량을 효율적인 방식으로 확장하면서 원래의 Transformer 매개변수화에 최대한 가깝게 접근합니다. 본 연구에서는 새로운 모델이 정확한 원래 주의 함수를 포착하도록 시도하지 않고 대신 선형화된 형태를 증류의 출발점으로 사용합니다.

알고리즘 1에서 볼 수 있듯이 이 연구에서는 Attention 메커니즘의 표준 Q, K, V 헤드를 Mamba 이산화에 직접 입력한 다음 결과 선형 RNN을 적용합니다. 이는 대략적인 초기화를 위해 선형 주의를 사용하는 것으로 생각할 수 있으며 모델이 확장된 숨겨진 상태를 통해 더 풍부한 상호 작용을 학습할 수 있도록 합니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

이 연구에서는 Transformer 어텐션 헤드를 미세 조정된 선형 RNN 레이어로 직접 대체하여 Transformer MLP 레이어를 변경하지 않고 훈련시키지 않습니다. 이 접근 방식은 헤드 간에 키와 값을 공유하는 그룹화된 쿼리 주의와 같은 다른 구성 요소도 처리해야 합니다. 연구팀은 이 아키텍처가 많은 Mamba 시스템에서 사용되는 것과 달리 이 초기화를 통해 모든 주의 블록을 선형 RNN 블록으로 대체할 수 있다는 점에 주목했습니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

이 연구에서는 하드웨어 인식 다단계 생성을 사용하여 선형 RNN 추측 디코딩을 위한 새로운 알고리즘도 제안합니다.

알고리즘 2와 그림 2는 완전한 알고리즘을 보여줍니다. 이 접근 방식은 검증을 위해 캐시에 RNN 숨겨진 상태만 유지하고 다단계 커널의 성공에 따라 느리게 발전합니다. 증류 모델에는 변환기 레이어가 포함되어 있으므로 이 연구에서는 추론적 디코딩을 Attention/RNN 하이브리드 아키텍처로 확장합니다. 이 설정에서 RNN 레이어는 알고리즘 2에 따라 검증을 수행하는 반면 Transformer 레이어는 병렬 검증만 수행합니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

본 연구에서는 이 방법의 유효성을 검증하기 위해 Mamba 7B와 Mamba 2.8B를 추측 대상 모델로 사용했습니다. 결과를 표 1에 나타내었다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

그림 3은 다단계 커널 자체의 성능 특성을 보여준다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

H100 GPU의 가속. 본 연구에서 제안한 알고리즘은 위의 표 1과 같이 Ampere GPU에서 강력한 성능을 보여줍니다. 그러나 H100 GPU에는 큰 과제가 있습니다. 이는 주로 GEMM 작업이 너무 빠르기 때문에 캐싱 및 재계산 작업으로 인한 오버헤드가 더 눈에 띄게 되기 때문입니다. 실제로 연구된 알고리즘의 간단한 구현(여러 개의 서로 다른 커널 호출 사용)은 3090 GPU에서 상당한 속도 향상을 달성했지만 H100에서는 전혀 속도 향상이 없었습니다.

실험 및 결과

본 연구에서는 실험을 위해 두 가지 LLM 채팅 모델을 사용했습니다. Zephyr-7B는 Mistral 7B 모델을 기반으로 미세 조정되었으며 Llama-3 Instruct 8B. 선형 RNN 모델의 경우, 이 연구에서는 Attention Layer가 각각 50%, 25%, 12.5%, 0%인 Mamba와 Mamba2의 하이브리드 버전을 사용하고 0%를 순수 Mamba 모델이라고 부릅니다. Mamba2는 주로 최신 GPU 아키텍처용으로 설계된 Mamba의 아키텍처 변형입니다.

채팅 벤치마크 평가

표 2는 해당 모델의 채팅 벤치마크 성능을 보여준다. 비교 대상이 되는 주요 모델은 대형 Transformer 모델이다. 결과는 다음과 같습니다.

증류된 하이브리드 Mamba 모델(50%)은 MT 벤치마크에서 Teacher 모델과 비슷한 점수를 달성했으며, LC 승률 및 LC 승률 측면에서 AlpacaEval 벤치마크의 Teacher 모델보다 약간 더 우수했습니다. 전체 승률 .

증류된 하이브리드 Mamba(25% 및 12.5%)의 성능은 MT 벤치마크에서 Teacher 모델보다 약간 낮지만 AlpcaaEval에 더 많은 매개변수를 적용하더라도 여전히 일부 대형 Transformer보다 성능이 좋습니다.

증류된 순수(0%) Mamba 모델의 정확도는 크게 떨어집니다.

5T 이상의 토큰을 사용하여 처음부터 훈련하는 Falcon Mamba보다 증류된 하이브리드 모델이 더 나은 성능을 발휘한다는 점은 주목할 가치가 있습니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

일반 벤치마크 평가

제로샘플 평가. 표 3은 LM Eval 벤치마크의 다양한 교사 모델에서 추출한 Mamba 및 Mamba2의 제로샷 성능을 보여줍니다. Llama-3 Instruct 8B에서 증류된 하이브리드 Mamba-Llama3 및 Mamba2-Llama3 모델은 처음부터 훈련된 오픈 소스 TRI Mamba 및 Nvidia Mamba 모델에 비해 성능이 더 좋습니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

벤치마크 평가. 표 4는 증류된 하이브리드 모델의 성능이 Open LLM Leaderboard에서 최고의 오픈 소스 선형 RNN 모델과 일치하는 동시에 GSM8K 및 CRUX에서 해당 오픈 소스 명령 모델보다 성능이 우수하다는 것을 보여줍니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

하이브리드 추측 복호

50% 및 25% 증류 모델에 대해 비투기 기준과 비교하여 본 연구에서는 Zephyr-Hybrid에서 1.8배 이상의 속도 향상을 달성했습니다.

실험에서도 본 연구에서 훈련된 4-layer 드래프트 모델이 더 높은 수신률을 달성하지만 드래프트 모델의 크기가 증가함에 따라 추가 오버헤드도 커지는 것으로 나타났습니다. 후속 작업에서 이 연구는 이러한 초안 모델을 축소하는 데 중점을 둘 것입니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

다른 증류 방법과의 비교: 표 6(왼쪽)은 다양한 모델 변형의 복잡성을 비교합니다. 이 연구는 Ultrachat을 시드 프롬프트로 사용하여 한 시대 내에서 증류를 수행하고 당혹감을 비교했습니다. 더 많은 레이어를 제거하면 상황이 더욱 악화되는 것으로 나타났습니다. 또한 이 연구에서는 증류 방법을 이전 기준선과 비교한 결과 새로운 방법이 더 작은 성능 저하를 보인 반면 Distill Hyena 모델은 훨씬 더 작은 모델을 사용하여 WikiText 데이터 세트에서 훈련되었으며 더 큰 혼동 정도의 성능 저하를 보였다는 사실을 발견했습니다.

표 6(오른쪽)을 보면 SFT나 DPO만 사용하면 그다지 개선되지 않는 반면, SFT + DPO를 사용하면 가장 좋은 점수를 얻을 수 있음을 알 수 있습니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

표 7에서는 여러 가지 모델에 대한 절제 연구를 비교합니다. 표 7(왼쪽)은 다양한 초기화를 사용한 증류 결과를 보여주고, 표 7(오른쪽)은 Mamba를 사용한 점진적 증류 및 인터리빙 Attention 레이어에서 더 작은 이득을 보여줍니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

표 8은 두 가지 초기화 방법을 사용하는 하이브리드 모델의 성능을 비교합니다. 결과는 어텐션 가중치의 초기화가 중요하다는 것을 확인시켜줍니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

표 9는 Mamba 블록이 있는 모델과 없는 모델의 성능을 비교합니다. Mamba 블록이 있는 모델은 Mamba 블록이 없는 모델보다 훨씬 더 나은 성능을 발휘합니다. 이는 Mamba 레이어를 추가하는 것이 중요하며 성능 향상이 단지 남아 있는 Attention 메커니즘 때문만은 아니라는 점을 확인시켜 줍니다.

Mamba作者新作:将Llama3蒸馏成混合线性 RNN

관심 있는 독자는 논문 원문을 읽고 연구 내용에 대해 자세히 알아볼 수 있습니다.

위 내용은 Mamba 저자의 새로운 작업: Llama3을 하이브리드 선형 RNN으로 증류의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.