Detailed explanation of GQA, the attention mechanism commonly used in large models, and PyTorch code implementation

Grouped Query Attention (GQA) is an attention method used in large language models that sits between multi-head attention (MHA) and multi-query attention (MQA). Its goal is to achieve the quality of MHA while maintaining the speed of MQA. GQA divides the query heads into groups, and the query heads within each group share the same key and value heads, which helps reduce computational cost and increase inference speed.

In this article, we will explain the idea of GQA and how to translate it into code.

GQA was proposed in the paper GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. It is a fairly simple and clean idea, built on top of multi-head attention.


GQA

A standard multi-head attention (MHA) layer consists of H query heads, H key heads, and H value heads. Each head has D dimensions. The PyTorch code is as follows:

    import torch
    from torch.nn.functional import scaled_dot_product_attention

    # shapes: (batch_size, seq_len, num_heads, head_dim)
    query = torch.randn(1, 256, 8, 64)
    key = torch.randn(1, 256, 8, 64)
    value = torch.randn(1, 256, 8, 64)

    # scaled_dot_product_attention expects (batch_size, num_heads, seq_len, head_dim),
    # so swap the seq_len and num_heads axes before the call and swap them back after
    output = scaled_dot_product_attention(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
    ).transpose(1, 2)
    print(output.shape) # torch.Size([1, 256, 8, 64])
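To make explicit what scaled_dot_product_attention computes for each head, here is a minimal sketch written with plain matrix operations. The helper name manual_attention and the small tensor sizes are assumptions chosen for illustration; it is not the library's internal implementation.

```python
import math

import torch
import torch.nn.functional as F


def manual_attention(q, k, v):
    # q, k, v: (batch_size, num_heads, seq_len, head_dim)
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # (b, h, n, s)
    weights = torch.softmax(scores, dim=-1)                # normalize over keys
    return torch.matmul(weights, v)                        # (b, h, n, d)


torch.manual_seed(0)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

reference = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(manual_attention(q, k, v), reference, atol=1e-4))
```

Each of the 8 heads is processed independently here, which is the property that GQA relaxes for keys and values.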

Each query head has its own corresponding key head and value head. This process is shown in the figure below:

[Figure: multi-head attention, with one key head and one value head per query head]

GQA, in contrast, divides the query heads into G groups, and each group shares a single key head and value head. It can be illustrated as follows:

[Figure: grouped query attention, with each group of query heads sharing one key head and one value head]
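The grouping can also be written down as a simple index mapping. The sketch below assumes that consecutive query heads belong to the same group, which matches the "(h g)" rearrange pattern used later in this article; the variable names are illustrative.

```python
num_query_heads = 8
num_kv_heads = 2  # G, the number of groups

# The query heads must split evenly across the key/value heads.
assert num_query_heads % num_kv_heads == 0
group_size = num_query_heads // num_kv_heads

# Assumed convention: consecutive query heads share one key/value head.
kv_head_for_query = [q // group_size for q in range(num_query_heads)]
print(kv_head_for_query)  # [0, 0, 0, 0, 1, 1, 1, 1]
```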

This visual representation makes the working principle of GQA clear. As noted above, GQA is a fairly simple and clean idea.

PyTorch code implementation

Let us write code that divides the query heads into G groups, with each group sharing a key and a value. We can use the einops library to perform complex tensor operations concisely.

First, define the queries, keys, and values, and set the number of attention heads. The numbers are arbitrary, but the number of query heads must be divisible by the number of key heads, that is, num_heads_for_query % num_heads_for_key == 0. Our definition is as follows:

    import torch

    # shapes: (batch_size, seq_len, num_heads, head_dim)
    query = torch.randn(1, 256, 8, 64)
    key = torch.randn(1, 256, 2, 64)
    value = torch.randn(1, 256, 2, 64)

    num_head_groups = query.shape[2] // key.shape[2]
    print(num_head_groups) # each group is of size 4 since there are 2 kv_heads

To make the computation more efficient, swap the seq_len and num_heads dimensions. With einops this can be done simply as follows:

    from einops import rearrange

    query = rearrange(query, "b n h d -> b h n d")
    key = rearrange(key, "b s h d -> b h s d")
    value = rearrange(value, "b s h d -> b h s d")

Then the concept of "grouping" needs to be introduced into the query matrix.

    from einops import rearrange

    query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
    print(query.shape) # torch.Size([1, 4, 2, 256, 64])

With the code above we reshape one dimension into two: for the tensors we defined, the original dimension of size 8 (the number of query heads) is now split into 2 groups (to match the number of heads in the keys and values), each of size 4.
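To see which query head ends up where, the same regrouping can be reproduced with plain PyTorch calls. This is a sketch under the assumption that view followed by permute mirrors the "b (h g) n d -> b g h n d" pattern; the variable names are illustrative.

```python
import torch

batch, num_q_heads, seq_len, head_dim = 1, 8, 256, 64
num_kv_heads = 2
group_size = num_q_heads // num_kv_heads  # 4

query = torch.randn(batch, num_q_heads, seq_len, head_dim)

# Split the head axis into (kv_head, position_in_group), then move the
# position_in_group axis in front of the kv_head axis.
grouped = query.view(batch, num_kv_heads, group_size, seq_len, head_dim)
grouped = grouped.permute(0, 2, 1, 3, 4)
print(grouped.shape)  # torch.Size([1, 4, 2, 256, 64])

# Query head 5 is the member with index 1 of the group for kv head 1.
print(torch.equal(grouped[:, 1, 1], query[:, 5]))  # True
```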

The last and hardest part is calculating the attention scores. In fact, it can be done in one line with the einsum operation:

    from einops import einsum, rearrange

    # g is the dimension of size num_head_groups (query heads per key/value head)
    # h is the number of key/value heads
    # n and s are equal and stand for the sequence length; d is the head dimension
    scores = einsum(query, key, "b g h n d, b h s d -> b h n s")
    print(scores.shape) # torch.Size([1, 2, 256, 256])

The scores tensor has the same number of heads as the value tensor above. Let's see how it works.

einsum does two things for us:

1. A matrix multiplication of the queries and keys. In our case, the shapes of these tensors are (1, 4, 2, 256, 64) and (1, 2, 256, 64), so matrix multiplication over the last two dimensions gives (1, 4, 2, 256, 256).

2. A sum over the second dimension (dimension g). If a dimension is omitted from the specified output shape, einsum sums over it automatically. This summation is what makes the number of heads in the scores match the number of heads in the keys and values.
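The two steps can be checked numerically. The sketch below uses torch.einsum instead of the einops version (an assumption made to keep the example free of extra dependencies) and a shorter sequence length to keep it fast.

```python
import torch

torch.manual_seed(0)
query = torch.randn(1, 4, 2, 16, 64)  # (b, g, h, n, d)
key = torch.randn(1, 2, 16, 64)       # (b, h, s, d)

# One-step version: g is absent from the output, so it is summed over.
scores = torch.einsum("bghnd,bhsd->bhns", query, key)

# Two-step version: batched matmul (key broadcasts over g), then sum over g.
per_group = torch.matmul(query, key.transpose(-2, -1))  # (b, g, h, n, s)
scores_two_step = per_group.sum(dim=1)                  # (b, h, n, s)

print(per_group.shape)        # torch.Size([1, 4, 2, 16, 16])
print(scores_two_step.shape)  # torch.Size([1, 2, 16, 16])
print(torch.allclose(scores, scores_two_step, atol=1e-3))
```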

Finally, apply the standard multiplication of the attention scores and the values:

    import torch.nn.functional as F

    scale = query.size(-1) ** 0.5
    attention = F.softmax(scores / scale, dim=-1)

    # here we do just a standard matrix multiplication
    out = einsum(attention, value, "b h n s, b h s d -> b h n d")

    # finally, just reshape back to (batch_size, seq_len, num_kv_heads, head_dim)
    out = rearrange(out, "b h n d -> b n h d")
    print(out.shape) # torch.Size([1, 256, 2, 64])

The simplest GQA implementation is now complete, requiring fewer than 16 lines of Python code.
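For convenience, the steps above can be gathered into one function. This is a sketch that follows the same sequence of operations but uses torch.einsum and permute rather than einops; the function name grouped_query_attention and the divisibility check are additions for illustration.

```python
import torch
import torch.nn.functional as F


def grouped_query_attention(query, key, value):
    """query: (b, n, h_q, d); key, value: (b, s, h_kv, d) -> (b, n, h_kv, d)."""
    b, n, h_q, d = query.shape
    h_kv = key.shape[2]
    if h_q % h_kv != 0:
        raise ValueError("query heads must be divisible by key/value heads")
    g = h_q // h_kv

    q = query.permute(0, 2, 1, 3)  # (b, h_q, n, d)
    k = key.permute(0, 2, 1, 3)    # (b, h_kv, s, d)
    v = value.permute(0, 2, 1, 3)  # (b, h_kv, s, d)
    q = q.reshape(b, h_kv, g, n, d).permute(0, 2, 1, 3, 4)  # (b, g, h_kv, n, d)

    scores = torch.einsum("bghnd,bhsd->bhns", q, k)  # summed over g
    attention = F.softmax(scores / d ** 0.5, dim=-1)
    out = torch.einsum("bhns,bhsd->bhnd", attention, v)
    return out.permute(0, 2, 1, 3)  # (b, n, h_kv, d)


out = grouped_query_attention(
    torch.randn(1, 256, 8, 64),
    torch.randn(1, 256, 2, 64),
    torch.randn(1, 256, 2, 64),
)
print(out.shape)  # torch.Size([1, 256, 2, 64])
```

Because the scores are summed over the group dimension, the output has one head per key/value head rather than one per query head.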

Finally, a brief word about MQA: Multi-Query Attention (MQA) is another popular method for simplifying MHA. All queries share the same keys and values. The schematic diagram is as follows:

[Figure: multi-query attention, with all query heads sharing a single key head and value head]

As you can see, both MQA and MHA can be derived from GQA. GQA with a single key and value head is equivalent to MQA, while GQA with as many groups as there are query heads is equivalent to MHA.
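The two special cases can be illustrated with a small sketch. Note that this is a different, hypothetical variant from the einsum-based steps earlier in the article: instead of summing over the group dimension, it repeats each key/value head for every query head in its group, so it returns one output per query head. The helper name gqa_per_head is an assumption.

```python
import torch
import torch.nn.functional as F


def gqa_per_head(query, key, value):
    """query: (b, h_q, n, d); key, value: (b, h_kv, s, d) -> (b, h_q, n, d)."""
    group_size = query.shape[1] // key.shape[1]
    # Repeat each key/value head so all query heads in a group see the same one.
    key = key.repeat_interleave(group_size, dim=1)
    value = value.repeat_interleave(group_size, dim=1)
    return F.scaled_dot_product_attention(query, key, value)


torch.manual_seed(0)
q = torch.randn(1, 8, 16, 64)

# h_kv == h_q: every query head has its own key/value head, which is MHA.
k_mha, v_mha = torch.randn(1, 8, 16, 64), torch.randn(1, 8, 16, 64)
mha_match = torch.allclose(
    gqa_per_head(q, k_mha, v_mha),
    F.scaled_dot_product_attention(q, k_mha, v_mha),
    atol=1e-4,
)

# h_kv == 1: all query heads share one key/value head, which is MQA.
k_mqa, v_mqa = torch.randn(1, 1, 16, 64), torch.randn(1, 1, 16, 64)
mqa_out = gqa_per_head(q, k_mqa, v_mqa)

print(mha_match, mqa_out.shape)
```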

What are the benefits of GQA?

GQA is a very good trade-off between the best performance (MQA) and the best model quality (MHA).

The following figure shows that with GQA you can obtain almost the same model quality as MHA while speeding up processing by a factor of 3, approaching the performance of MQA. This may be essential for high-load systems.

[Figure: comparison of MHA, GQA, and MQA]
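One way to get a feel for the saving, as an illustration rather than a benchmark, is to count the elements in the key and value tensors for the shapes used in this article. The helper kv_elements is hypothetical, and the count says nothing about measured speed.

```python
def kv_elements(batch_size, seq_len, num_kv_heads, head_dim):
    # Keys and values have the same shape, hence the factor of 2.
    return 2 * batch_size * seq_len * num_kv_heads * head_dim


shape = dict(batch_size=1, seq_len=256, head_dim=64)
mha = kv_elements(num_kv_heads=8, **shape)  # one kv head per query head
gqa = kv_elements(num_kv_heads=2, **shape)  # two groups
mqa = kv_elements(num_kv_heads=1, **shape)  # a single shared kv head
print(mha, gqa, mqa)  # 262144 65536 32768
```

The key and value tensors shrink in proportion to the number of key/value heads, while the query tensor keeps all 8 heads in every case.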

At the time of writing, there is no official implementation of GQA in PyTorch, so here is a good unofficial implementation I found. If you are interested, you can try it:

https://www.php.cn/link/5b52e27a9d5bf294f5b593c4c071500e

GQA paper:

https://www.php.cn/link/e4ba31fba036a999321d5460f7f2d1d1


Statement
This article is reproduced from 51CTO.COM. If there is any infringement, please contact admin@php.cn to have it removed.