
Quantizing Qwen3 with QuaRot and Running Online Inference



The code has been uploaded to: https://github.com/taishan1994/LLM-Quantization#

QuaRot rotation quantization with online rotation requires modifying the model's forward pass. Here we enable online rotation and adapt it to transformers inference.
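Before diving in, a quick sanity check on why rotation is "free": for an orthonormal Hadamard matrix H, folding H^T into the weights offline and applying H to the activations online leaves the layer output unchanged, while the rotated activations have far fewer outliers and quantize better. A minimal PyTorch sketch (dimensions are illustrative, not Qwen3's):

import torch

def hadamard_matrix(n: int) -> torch.Tensor:
    # Sylvester construction of an n x n Hadamard matrix (n a power of two)
    H = torch.ones(1, 1, dtype=torch.float64)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
    return H / n ** 0.5  # orthonormal: H @ H.T = I

d = 128
H = hadamard_matrix(d)
W = torch.randn(64, d, dtype=torch.float64)  # a linear layer's weight
x = torch.randn(d, dtype=torch.float64)      # an activation vector

W_rot = W @ H.T          # folded into the checkpoint offline
y_ref = W @ x
y_rot = W_rot @ (H @ x)  # the online rotation happens in forward()
print(torch.allclose(y_ref, y_rot))  # True: the rotation cancels exactly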

  • Quantization framework: llmc
  • Inference framework: transformers (vllm or sglang could be adapted the same way)

First, install the base environment following llmc's setup instructions. Then create a mine folder under configs/quantization and define the online-rotation quantization config there as quarot_w_a.yml. llmc currently supports online rotation only for OPT and Llama models, so set the model type to Llama (the Qwen3 and Llama architectures are essentially the same):

base:
    seed: &seed 42
model:
    type: Llama
    path: /data/gongoubo/checkpoints/Qwen/Qwen3-8B
    torch_dtype: auto
quant:
    method: Quarot
    weight:
        bit: 8
        symmetric: True
        granularity: per_channel
        group_size: -1
        calib_algo: minmax
    act:
        bit: 8
        symmetric: True
        granularity: per_token
    special:
        rotate_mode: hadamard
        fp32_had: True
        online_rotate: True
save:
    save_trans: True
    save_fake: True
    save_vllm: True
    save_path: /data/gongoubo/checkpoints/Qwen/llmc/Qwen3-8B-w8a8-online

We use W8A8 quantization, with per-channel weights and per-token activations. Note the three save options:

  • save_trans: saves the QuaRot-rotated but unquantized weights.
  • save_fake: saves the QuaRot-rotated weights after quantization followed by dequantization (fake quantization).
  • save_vllm: saves the QuaRot-rotated, truly quantized weights.

Since transformers does not support real W8A8 inference, we use the weights saved by save_fake, as sketched below.
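To make save_fake concrete, here is a hedged sketch of symmetric per-channel fake quantization (quantize to int8, immediately dequantize); the helper name is made up for illustration and is not llmc's API:

import torch

def fake_quant_per_channel(w: torch.Tensor, bits: int = 8) -> torch.Tensor:
    qmax = 2 ** (bits - 1) - 1                        # 127 for symmetric int8
    scale = w.abs().amax(dim=1, keepdim=True) / qmax  # one scale per output channel
    q = torch.clamp(torch.round(w / scale), -qmax, qmax)
    return q * scale                                  # dequantize back to float

w = torch.randn(4096, 4096)
w_fake = fake_quant_per_channel(w)
print((w - w_fake).abs().max())  # quantization error, at most scale / 2 per channel

The resulting checkpoint is plain floating point, which is exactly why an ordinary transformers forward pass can consume it while still exhibiting the quantization error.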

To apply online rotation we need to modify two places: inside attention (between v_proj and o_proj) and inside the MLP (between up_proj and down_proj).

class Qwen3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

        # for Qwen3-8B, get_hadK(12288) yields K = 12 (12288 = 12 * 1024)
        had_K_tensor, K_tensor = get_hadK(self.intermediate_size)
        self.rotater = Rotater(
            online_full_had=True,  # for MLPs we use the full online Hadamard transform
            online_partial_had=False,
            fp32_had=True,
            K=K_tensor,
            had_K=had_K_tensor,
            had_dim=None,  # had_dim is not used for the full transform
        )
        print('enable online rotate for Qwen3MLP')

    def forward(self, x):
        act = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        # rotate the down_proj input online, right before the quantized matmul
        act = self.rotater.rotate(act)
        down_proj = self.down_proj(act)
        return down_proj
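Conceptually, rotater.rotate applies a fast Walsh-Hadamard transform to the down_proj input (llmc dispatches to a fused fast_hadamard_transform kernel; the reference implementation below covers only power-of-two lengths and is for illustration). For non-power-of-two sizes such as Qwen3-8B's intermediate_size of 12288, get_hadK factors the dimension (12288 = 12 × 1024) and combines a fixed 12×12 Hadamard matrix with the fast transform:

import torch

def fwht(x: torch.Tensor) -> torch.Tensor:
    # fast Walsh-Hadamard transform over the last dim (power-of-two length)
    orig_shape = x.shape
    d = orig_shape[-1]
    x = x.reshape(-1, d).clone()
    h = 1
    while h < d:
        y = x.view(-1, d // (2 * h), 2, h)
        a, b = y[:, :, 0, :], y[:, :, 1, :]
        x = torch.stack([a + b, a - b], dim=2).reshape(-1, d)
        h *= 2
    return x.reshape(orig_shape) / d ** 0.5  # orthonormal scaling

x = torch.randn(2, 4096)
print(torch.allclose(fwht(fwht(x)), x, atol=1e-4))  # H/sqrt(d) is its own inverse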
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_head = config.num_attention_heads
        self.num_kv_head = config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

        # for attention, we use the partial Hadamard transform across heads
        had_K_tensor, K_tensor = get_hadK(self.num_head)
        print(had_K_tensor, K_tensor, self.num_head)
        self.rotater = Rotater(
            online_full_had=False,  # for attention we use the online partial Hadamard transform
            online_partial_had=True,
            fp32_had=True,
            K=K_tensor,
            had_K=had_K_tensor,
            had_dim=self.head_dim,
        )

        print("enable Qwen3Attention")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # An optional (disabled) variant also rotates query/key states here via
        # fast_hadamard_transform.hadamard_transform(..., scale=1 / math.sqrt(self.head_dim));
        # it is not required for the v/o + up/down online rotation used in this post.

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        # online rotation of the attention output, right before o_proj
        attn_output = attn_output.reshape(-1, self.num_head * self.head_dim)
        attn_output = self.rotater.rotate(attn_output)
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()

        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
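Note the inverted configuration here: online_partial_had=True with had_dim=head_dim, and the Hadamard size comes from get_hadK(self.num_head). As I understand QuaRot's partial transform, the rotation mixes the attention output across heads while keeping each head_dim slice intact. A hedged sketch of the idea (names are illustrative; llmc's Rotater uses a fused kernel):

import torch

def partial_hadamard(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
    # build the num_heads x num_heads Hadamard matrix (power of two assumed)
    H = torch.ones(1, 1, dtype=x.dtype, device=x.device)
    while H.shape[0] < num_heads:
        H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
    x = x.reshape(-1, num_heads, head_dim)  # (tokens, heads, head_dim)
    x = (H @ x) / num_heads ** 0.5          # rotate across the head axis only
    return x.reshape(-1, num_heads * head_dim)

out = torch.randn(5, 32 * 128)               # e.g. 32 heads of dim 128
print(partial_hadamard(out, 32, 128).shape)  # torch.Size([5, 4096])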

At inference time, run the model as usual:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

model_name = "/data/gongoubo/checkpoints/Qwen/llmc/Qwen3-8B-w8a8-online/fake_quant_model/"
tokenizer = AutoTokenizer.from_pretrained(model_name)
CONFIG = AutoConfig.from_pretrained(model_name)
# transformers==4.53.0

process_word_embeddings = False
if CONFIG.tie_word_embeddings:
    CONFIG.tie_word_embeddings = False
    process_word_embeddings = True

# import the modified model definition with online rotation
from modeling_qwen3_online_llmc import Qwen3ForCausalLM

model = Qwen3ForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, config=CONFIG).to("cuda:0")

# if the embeddings were tied, copy them into the LM head:
# model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

# make sure the model is in eval mode
model.eval()

# the prompt could also be built with tokenizer.apply_chat_template
message = "<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer.encode(message, return_tensors="pt")
input_ids = input_ids.to(model.device)
direct_output = model.generate(input_ids, max_new_tokens=256, do_sample=False, temperature=1)
direct_text = tokenizer.decode(direct_output[0])

print(direct_text)

Note: load the checkpoint with the modified Qwen3 model definition (modeling_qwen3_online_llmc), not the stock transformers implementation.
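As a hedged sanity check (paths are those used above; the exact tolerance depends on the model and calibration data), you can compare the next-token logits of the fake-quantized online-rotation model against the original Qwen3-8B; a small gap and a matching argmax suggest the rotation is wired up correctly:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from modeling_qwen3_online_llmc import Qwen3ForCausalLM

quant_path = "/data/gongoubo/checkpoints/Qwen/llmc/Qwen3-8B-w8a8-online/fake_quant_model/"
orig_path = "/data/gongoubo/checkpoints/Qwen/Qwen3-8B"

tokenizer = AutoTokenizer.from_pretrained(quant_path)
# needs enough VRAM for both fp16 8B models on one device
quant = Qwen3ForCausalLM.from_pretrained(quant_path, torch_dtype=torch.float16).to("cuda:0").eval()
orig = AutoModelForCausalLM.from_pretrained(orig_path, torch_dtype=torch.float16).to("cuda:0").eval()

ids = tokenizer.encode("你是谁?", return_tensors="pt").to("cuda:0")
with torch.no_grad():
    logits_q = quant(ids).logits[:, -1]
    logits_o = orig(ids).logits[:, -1]
print((logits_q - logits_o).abs().max().item())        # quantization-induced gap
print((logits_q.argmax() == logits_o.argmax()).item())  # ideally True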

