Simple QuaRot proof of concept. #407

Open
wants to merge 2 commits into master
40 changes: 38 additions & 2 deletions exllamav2/attn.py
@@ -15,6 +15,9 @@
# from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak
# import torch.nn.functional as F

from auto_quarot import hadamard_utils
import fast_hadamard_transform

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from exllamav2.model import ExLlamaV2
@@ -61,6 +64,11 @@ class ExLlamaV2Attention(ExLlamaV2Module):

has_norm: bool
has_residual: bool
quarot: bool
kv_quarot: bool
had_K: torch.Tensor
K: int
had_dim: int


class Params:
Expand Down Expand Up @@ -182,7 +190,8 @@ def __init__(self,
key: str,
layer_idx: int,
has_norm: bool = True,
has_residual: bool = True):
has_residual: bool = True,
quarot: bool = False):

super().__init__(model, key)

@@ -191,6 +200,12 @@ def __init__(self,
self.layer_idx = layer_idx
self.has_norm = has_norm
self.has_residual = has_residual
self.quarot = quarot
self.kv_quarot = True # should be an option

if self.quarot:
self.had_K, self.K = hadamard_utils.get_hadK(model.config.num_attention_heads)
self.had_dim = model.config.hidden_size // model.config.num_attention_heads

self.q_handle = None
self.temp_lora_size = 0
@@ -315,6 +330,9 @@ def load(self):
self.model.config.arch.rope_neox_style,
q_norm,
k_norm)

if self.quarot and self.had_K is not None:
self.had_K = self.had_K.to(self.device_idx)


def unload(self):
@@ -447,7 +465,7 @@ def forward(self,

global has_flash_attn

if self.q_handle is None or intermediates:
if self.quarot or self.kv_quarot or self.q_handle is None or intermediates:
return self.forward_torch(hidden_states,
cache,
attn_params,
@@ -750,6 +768,12 @@ def forward_torch(self,
ext_c.rope_(query_states, constants.sin, constants.cos, past_len, num_attention_heads, head_dim, position_offsets, self.model.config.arch.rope_neox_style)
ext_c.rope_(key_states, constants.sin, constants.cos, past_len, num_key_value_heads, head_dim, position_offsets, self.model.config.arch.rope_neox_style)

# Add another rotation for the keys/queries

if self.kv_quarot:
query_states = fast_hadamard_transform.hadamard_transform(query_states.float(), scale=1/math.sqrt(query_states.shape[-1])).to(query_states.dtype)
key_states = fast_hadamard_transform.hadamard_transform(key_states.float(), scale=1/math.sqrt(key_states.shape[-1])).to(key_states.dtype)

# Add keys and values to cache

if cache is not None:
@@ -800,6 +824,18 @@ def forward_torch(self,
if cache is not None:
cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len)

# QuaRot before output

if self.quarot:
init_shape = attn_output.shape
if self.K == 1:
attn_output = fast_hadamard_transform.hadamard_transform(attn_output.reshape(-1, init_shape[-1]//self.had_dim, self.had_dim).transpose(1, 2),
scale=1/math.sqrt(init_shape[-1]//self.had_dim)).transpose(1, 2)
else:
attn_output = (self.had_K.to(attn_output.dtype) @ attn_output.reshape(-1, init_shape[-1]//self.had_dim, self.had_dim)) / math.sqrt(init_shape[-1]//self.had_dim)

attn_output = attn_output.reshape(init_shape)

# Output projection

attn_proj = self.o_proj.forward(attn_output, loras = loras)
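
Note on the kv_quarot rotation added above: the Hadamard transform is orthogonal, so rotating queries and keys by the same transform leaves the attention scores Q @ K^T unchanged; the extra rotation only flattens outliers so the cached keys quantize better. A minimal self-contained sketch of that invariance, assuming a CUDA device and a power-of-two head_dim (the shapes here are illustrative, not taken from the PR):

import math
import torch
import fast_hadamard_transform

head_dim = 128                                   # power of two, as in Llama-style models
q = torch.randn(4, head_dim, device="cuda")      # fp32 keeps the comparison tight
k = torch.randn(4, head_dim, device="cuda")

scale = 1.0 / math.sqrt(head_dim)                # same normalization as the diff above
q_rot = fast_hadamard_transform.hadamard_transform(q, scale=scale)
k_rot = fast_hadamard_transform.hadamard_transform(k, scale=scale)

# Scores match up to float rounding, so the attention output is unchanged as well.
print(torch.allclose(q @ k.T, q_rot @ k_rot.T, atol=1e-4))
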
5 changes: 5 additions & 0 deletions exllamav2/config.py
@@ -93,6 +93,7 @@ class ExLlamaV2Config:

checkpoint_fused_mlp: bool

quarot: bool = False

def __init__(self,
model_dir: str | None = None):
@@ -200,6 +201,10 @@ def prepare(self, no_tensors: bool = False):
"model_max_length",
"max_position_embeddings",
"max_seq_len"], 2048)

quarot_config = read(read_config, dict, "quarot_config", None)
if quarot_config is not None:
self.quarot = quarot_config["rotated"]

rs = read(read_config, dict, "rope_scaling", None)
if rs and "factor" in rs:
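
For reference, the only field prepare() reads for this feature is quarot_config["rotated"] in the model's config.json. A hedged sketch of the expected fragment and of how the flag is derived (key names other than "rotated" are not assumed):

import json

# Minimal fragment a rotated checkpoint's config.json would need for this reader.
snippet = '{ "quarot_config": { "rotated": true } }'
read_config = json.loads(snippet)

quarot_config = read_config.get("quarot_config")
quarot = quarot_config["rotated"] if quarot_config is not None else False
print(quarot)  # True
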
20 changes: 18 additions & 2 deletions exllamav2/mlp.py
@@ -13,6 +13,8 @@
if TYPE_CHECKING:
from exllamav2.model import ExLlamaV2

from auto_quarot import hadamard_utils


class ExLlamaV2MLP(ExLlamaV2Module):

@@ -30,26 +32,34 @@ class ExLlamaV2MLP(ExLlamaV2Module):

has_norm: bool
has_residual: bool
quarot: bool
had_K: torch.Tensor
K: int

def __init__(self,
model: ExLlamaV2,
key: str,
layer_idx: int,
has_norm: bool = True,
has_residual: bool = True):
has_residual: bool = True,
quarot: bool = False):

super().__init__(model, key)

self.layer_idx = layer_idx
self.has_norm = has_norm
self.has_residual = has_residual
self.quarot = quarot

self.q_handle = None
self.temp_lora_size = 0

hidden_size = self.model.config.hidden_size
intermediate_size = self.model.config.intermediate_size

if self.quarot:
self.had_K, self.K = hadamard_utils.get_hadK(intermediate_size)

if self.has_norm:
if self.model.config.arch.norm == "layernorm":
self.post_attention_layernorm = ExLlamaV2LayerNorm(model, key + self.model.config.arch.norm_key_2)
@@ -136,6 +146,9 @@ def load(self):
self.model.config.max_input_len * self.model.config.max_batch_size,
self.model.config.arch.mlp_act_func == "gelu",
self.has_residual)

if self.quarot:
self.had_K = self.had_K.to(self.device_idx)


def unload(self):
@@ -224,7 +237,7 @@ def forward(self,
loras: list[ExLlamaV2Lora] | None = None,
**kwargs) -> torch.Tensor | dict[str: torch.Tensor]:

if self.q_handle is None or intermediates:
if self.quarot or self.q_handle is None or intermediates:
return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs)

if loras is None or self.temp_lora_size == 0:
@@ -271,6 +284,9 @@ def forward_torch(self,
elif self.model.config.arch.mlp_act_func == "gelu":
y = F.gelu(up)

if self.quarot:
y = hadamard_utils.matmul_hadU_cuda(y, self.had_K, self.K)

down = self.down_proj.forward(y, loras = loras)
hidden_states = down + residual if self.has_residual else down

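
The rotation inserted before down_proj only gives the right answer because the checkpoint's down_proj weight is assumed to have been rotated offline by the matching Hadamard matrix: with an orthogonal H, (y @ H) @ (W @ H).T equals y @ W.T exactly. Below is a pure-PyTorch sketch of that identity for the simple power-of-two case (K == 1); it is not the AutoQuarot matmul_hadU_cuda kernel, just the algebra it relies on, and the sizes are illustrative:

import math
import torch

def hadamard(n: int) -> torch.Tensor:
    # Sylvester construction of an orthonormal Hadamard matrix; n must be a power of two.
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / math.sqrt(n)

intermediate, hidden = 256, 64                   # illustrative sizes only
H = hadamard(intermediate)
W_down = torch.randn(hidden, intermediate)       # down_proj weight, [out_features, in_features]
y = torch.randn(8, intermediate)                 # activation entering down_proj

reference = y @ W_down.T                         # original, un-rotated model
rotated = (y @ H) @ (W_down @ H).T               # offline-rotated weight plus online rotation of y
print(torch.allclose(reference, rotated, atol=1e-4))  # True, since H @ H.T = I
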
4 changes: 2 additions & 2 deletions exllamav2/model.py
@@ -169,9 +169,9 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx)
self.modules += [pd]
else:
attn = ExLlamaV2Attention(self, layer_key, layer_idx)
attn = ExLlamaV2Attention(self, layer_key, layer_idx, quarot=self.config.quarot)
if self.config.arch.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx)
else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx)
else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx, quarot=self.config.quarot)
self.modules += [attn, mlp]

if self.config.arch.norm == "layernorm": norm = ExLlamaV2LayerNorm(self, "model.norm")
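
With the flag wired through here, nothing changes at the call site: the usual loading sequence picks quarot up from config.json during prepare(), and ExLlamaV2.__init__ forwards it to the attention and MLP modules as shown above. A hedged usage sketch (the model path is a placeholder):

from exllamav2 import ExLlamaV2, ExLlamaV2Config

config = ExLlamaV2Config()
config.model_dir = "/path/to/quarot-rotated-model"   # placeholder path
config.prepare()                                     # sets config.quarot from quarot_config["rotated"]
model = ExLlamaV2(config)                            # modules are built with quarot = config.quarot
model.load()
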
3 changes: 2 additions & 1 deletion requirements.txt
@@ -10,4 +10,5 @@ pygments
websockets
regex
numpy
tokenizers
tokenizers
git+https://github.com/sgsdxzy/AutoQuarot.git