The code has been uploaded to: https://github.com/taishan1994/LLM-Quantization#
If online rotation is added to QuaRot rotation quantization, the model's forward pass has to be modified. Here we enable online rotation and adapt the model for inference with transformers.
- Inference framework used: transformers (vLLM or SGLang could of course be adapted in the same way)
First set up the base environment following llmc's installation instructions, then create a new mine folder under configs/quantization and define the online-rotation quantization config in it: quarot_w_a.yml. llmc currently only supports online rotation for OPT and Llama models, so set the model type to Llama (the Qwen3 and Llama architectures are essentially the same):
base:
    seed: &seed 42
model:
    type: Llama
    path: /data/gongoubo/checkpoints/Qwen/Qwen3-8B
    torch_dtype: auto
quant:
    method: Quarot
    weight:
        bit: 8
        symmetric: True
        granularity: per_channel
        group_size: -1
        calib_algo: minmax
    act:
        bit: 8
        symmetric: True
        granularity: per_token
    special:
        rotate_mode: hadamard
        fp32_had: True
        online_rotate: True
save:
    save_trans: True
    save_fake: True
    save_vllm: True
    save_path: /data/gongoubo/checkpoints/Qwen/llmc/Qwen3-8B-w8a8-online
We use W8A8 quantization, with per-channel quantization for weights and per-token quantization for activations. Note:
- save_trans: saves the QuaRot-rotated but unquantized weights.
- save_fake: saves the fake-quantized weights, i.e. rotated, quantized, then dequantized.
- save_vllm: saves the QuaRot-rotated and quantized weights (for vLLM).
Since transformers does not support W8A8 inference, we use the weights saved by save_fake.
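For intuition, "fake quant" means the weights on disk have already been rotated, quantized to int8, and dequantized back to floating point, so a plain fp16 forward pass in transformers reproduces the quantization error without needing int8 kernels. Below is a minimal sketch of symmetric minmax fake quantization with the granularities from the config above (an illustration only, not llmc's actual code):

import torch

def fake_quant_sym(x, bits=8, dim=-1):
    # symmetric minmax fake quantization: quantize to the int grid, then dequantize
    qmax = 2 ** (bits - 1) - 1
    scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8) / qmax
    return (x / scale).round().clamp(-qmax - 1, qmax) * scale

w = torch.randn(1024, 4096)        # a Linear weight [out_features, in_features]
x = torch.randn(8, 4096)           # a batch of token activations
w_fq = fake_quant_sym(w, dim=1)    # per_channel: one scale per output channel
x_fq = fake_quant_sym(x, dim=-1)   # per_token: one scale per token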
For online rotation we need to modify two places: the v/o projections in attention and the up/down projections in the MLP.
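Before looking at the modified modules, here is a quick sketch of why inserting a rotation does not change the layer output: for any orthogonal matrix Q (a scaled Hadamard matrix is one such Q), a Linear layer whose weight was right-multiplied by Q offline gives the same result on rotated inputs, while the rotated activations have far fewer outliers and are therefore easier to quantize per token (illustration only, sizes are arbitrary):

import torch

torch.manual_seed(0)
x = torch.randn(4, 64)                          # activations
w = torch.randn(128, 64)                        # Linear weight [out_features, in_features]
q, _ = torch.linalg.qr(torch.randn(64, 64))     # any orthogonal matrix

y_ref = x @ w.T                                 # original Linear output
y_rot = (x @ q) @ (w @ q).T                     # rotated input with offline-rotated weight
print(torch.allclose(y_ref, y_rot, atol=1e-4))  # True: the rotation cancels out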
class Qwen3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
        # self.K = 12
        self.had_K, self.K = get_hadK(self.intermediate_size)
        had_K_tensor, K_tensor = get_hadK(self.intermediate_size)
        self.rotater = Rotater(
            online_full_had=True,  # for MLPs, we use the online full Hadamard transform
            online_partial_had=False,
            fp32_had=True,
            K=K_tensor,
            had_K=had_K_tensor,
            had_dim=None,  # for MLPs, had_dim is not used
        )
        print('enable online rotate for Qwen3MLP')
        # Explicitly move tensors to the correct device and dtype
        # target_device = self.gate_proj.weight.device
        # self.rotater.had_K = self.rotater.had_K.to(device=target_device, dtype=torch.float32, non_blocking=True)

    def forward(self, x):
        act = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        # rotate the down_proj input online before the (fake-)quantized matmul
        act = self.rotater.rotate(act)
        down_proj = self.down_proj(act)
        # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
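As I understand llmc's full-Hadamard path (it follows QuaRot's matmul_hadU), Rotater.rotate here multiplies the last dimension by a scaled Hadamard matrix of size intermediate_size, factored as had_K ⊗ H_{2^m}. For Qwen3-8B, intermediate_size = 12288 = 12 × 1024, so get_hadK returns a 12×12 Hadamard factor and K = 12 (hence the "# self.K = 12" comment above). A rough, unoptimized reference of that transform (ordering/transpose conventions may differ from llmc's kernel; in practice the fused fast_hadamard_transform kernel is used instead of materializing H):

import math
import torch

def full_had_reference(x, had_K, K):
    # multiply the last dim (size n = K * 2^m) by (had_K ⊗ H_{2^m}) / sqrt(n)
    n = x.shape[-1]
    m = n // K
    H_m = torch.ones(1, 1, dtype=x.dtype, device=x.device)
    while H_m.shape[0] < m:  # Sylvester construction of the power-of-two factor H_{2^m}
        H_m = torch.cat([torch.cat([H_m, H_m], dim=1), torch.cat([H_m, -H_m], dim=1)], dim=0)
    H = torch.kron(had_K.to(x.dtype), H_m) / math.sqrt(n)
    return x @ H

Because llmc folds the matching rotation into down_proj.weight when exporting the checkpoint (save_trans / save_fake), rotating the activation online leaves down_proj's output unchanged while flattening the activation distribution.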
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_head = config.num_attention_heads
        self.num_kv_head = config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
        had_K_tensor, K_tensor = get_hadK(
            self.num_head
        )  # for attention, we use the partial Hadamard transform
        print(had_K_tensor, K_tensor, self.num_head)
        self.rotater = Rotater(
            online_full_had=False,  # for attention, we use the online partial Hadamard transform
            online_partial_had=True,
            fp32_had=True,
            K=K_tensor,
            had_K=had_K_tensor,
            had_dim=self.head_dim,
        )
        print("enable Qwen3Attention")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # perform online rotation here
        init_q_shape = query_states.shape
        init_k_shape = key_states.shape
        # # print(query_states.shape, key_states.shape)
        # query_states = (
        #     fast_hadamard_transform.hadamard_transform(
        #         query_states.to(torch.float32),
        #         scale=1 / math.sqrt(self.head_dim),
        #     )
        #     .reshape(init_q_shape)
        #     .contiguous()
        # ).to(value_states.dtype)
        #
        # key_states = (
        #     fast_hadamard_transform.hadamard_transform(
        #         key_states.to(torch.float32),
        #         scale=1 / math.sqrt(self.head_dim),
        #     )
        #     .reshape(init_k_shape)
        #     .contiguous()
        # ).to(value_states.dtype)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )
        # flatten the heads, rotate the o_proj input online, then restore the shape
        attn_output = attn_output.reshape(-1, self.num_head * self.head_dim)
        attn_output = self.rotater.rotate(attn_output)
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
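In the attention branch the rotation is partial: had_K has size num_heads and had_dim=head_dim, so (again, as I understand llmc's/QuaRot's partial-Hadamard path) it mixes the same channel index across heads while leaving each head's head_dim block intact; the per-head part of the rotation was already folded into v_proj and o_proj offline. The commented-out fast_hadamard_transform block on the query/key states corresponds to the optional extra Q/K rotation after RoPE, which is left disabled here. A rough reference of the partial path (illustration only):

import math
import torch

def partial_had_reference(x, num_heads, head_dim):
    # x: [tokens, num_heads * head_dim]; mix the same channel index across heads with a
    # scaled num_heads-sized Hadamard (Qwen3-8B has num_heads = 32, a power of two, so the
    # plain Walsh-Hadamard transform; otherwise had_K supplies the non-power-of-two factor)
    H = torch.ones(1, 1, dtype=x.dtype, device=x.device)
    while H.shape[0] < num_heads:  # Sylvester construction
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    t = x.reshape(-1, num_heads, head_dim)
    t = torch.einsum("hk,tkd->thd", H, t) / math.sqrt(num_heads)
    return t.reshape(x.shape)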
At inference time we can just run inference as usual:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

model_name = "/data/gongoubo/checkpoints/Qwen/llmc/Qwen3-8B-w8a8-online/fake_quant_model/"
# model_name = "/data/gongoubo/Qwen-1.5-Factory/model_hub/Qwen/Qwen2___5-1___5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
CONFIG = AutoConfig.from_pretrained(model_name)

# transformers==4.53.0
# from modeling_qwen3 import Qwen3ForCausalLM
process_word_embeddings = False
if CONFIG.tie_word_embeddings:
    CONFIG.tie_word_embeddings = False
    process_word_embeddings = True

from modeling_qwen3_online_llmc import Qwen3ForCausalLM
# from modeling_qwen3_online_r3_r4 import Qwen3ForCausalLM
# from modeling_qwen3 import Qwen3ForCausalLM

model = Qwen3ForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, config=CONFIG).to("cuda:0")
# model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

# make sure the model is in eval mode
model.eval()

# message = [{"role": "user", "content": "你是谁?"}]
# message = tokenizer.apply_chat_template(message, tokenize=False, add_special_tokens=True)
message = "<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer.encode(message, return_tensors="pt")
input_ids = input_ids.to(model.device)

direct_output = model.generate(input_ids, max_new_tokens=256, do_sample=False, temperature=1)
direct_text = tokenizer.decode(direct_output[0])
print(direct_text)
Note that when loading the model you must use the modified Qwen3 model definition (modeling_qwen3_online_llmc above), not the stock transformers one.
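To sanity-check the adaptation, it helps to run the same prompt through the original (unquantized) Qwen3-8B checkpoint from the config above and compare the two answers. A minimal sketch:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

orig_name = "/data/gongoubo/checkpoints/Qwen/Qwen3-8B"   # the unquantized checkpoint from the config
tokenizer = AutoTokenizer.from_pretrained(orig_name)
ref_model = AutoModelForCausalLM.from_pretrained(orig_name, torch_dtype=torch.float16).to("cuda:0")
ref_model.eval()

message = "<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer.encode(message, return_tensors="pt").to(ref_model.device)
ref_output = ref_model.generate(input_ids, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(ref_output[0]))    # compare with the fake-quant model's answer above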