|
OpenAI gpt-oss 模型 on Azure
本月初,我们发布了 gpt-oss 两款模型正式登陆 Azure AI Foundry (国际版),详细发布细节请参考 👉OpenAI 开源模型 gpt-oss 重磅登陆 Foundry 平台
gpt-oss-120b在核心推理基准测试中表现与OpenAI o4-mini 几乎持平,并且可在单个 80 GB GPU(NC_H100单卡机型)上高效运行。
gpt-oss-20b在常见基准测试的结果与 OpenAI o3-mini 相近,并且可以在仅有 16GB 内存的边缘设备上运行,非常适合本地推理、端侧部署或低成本的快速迭代场景。
这两款模型在工具调用、少样本函数调用(few-shot function calling)、CoT 推理(如 Tau-Bench 智能体评估套件上的测试)和 HealthBench (甚至优于 OpenAI o1 和 GPT 4o 等专有模型)方面也表现出色。
在这篇文章中,我将展示两款模型在 A10 和 H100 两款 GPU 上的性能,包括TTFT,每秒 tokens 数等。
假设我们有 8 个数(实际 MX 是 32,为了演示用小样本): [1.0,0.9,1.1,0.95,1.05,1.0,0.92,100.0]
P_i≈[0.0625,...,6.0(saturated)] [ INT4 (无符号示例) ] [ E2M1 (FP4) ]
┌───┬───┬───┬───┐ ┌───┬───┬───┬───┐│b3 │b2 │b1 │b0 │ ←4bits │ S │ E1│ E0│ M0│ ←4bits└───┴───┴───┴───┘ └───┴───┴───┴───┘b3..b0→ 整数值 (0~15) S: 符号位 (0=正,1=负) E1E0:2-bit 指数(Exponent) M0:1-bit 尾数(Mantissa)
在 gpt-oss 中,Sink Token 机制主要用于提升长上下文场景下的推理吞吐量与延迟表现,同时通过防止上下文丢失来保持准确性——需要注意的是,它并不会直接提升模型的推理能力。
在 gpt-oss 的推理流程中(尤其是基于 vLLM 推理时),Sink Token 是插入在输入序列最开头的一个特殊 Token。
全局上下文锚点(Global Context Anchor) Sink Token 会被序列中的所有 Token 关注,它就像是对整个 Prompt / 上下文的高维压缩摘要,为后续推理提供一个统一的参考点。 长上下文效率优化(Long Context Efficiency) 在长上下文推理中,大多数 Token 只会关注最近的 N 个 Token(滑动窗口机制)。但借助 Sink Token,它们仍能快速获取全局信息,而无需重新处理整个历史上下文
您可以把 Sink Token 视为 KV Cache 中的“固定全局记忆单元”——即使使用长上下文关注机制,它也不会被滑动窗口“挤掉”,始终为全局信息留有一个入口。
要让 Sink Token 真正发挥作用,需要一种混合注意力(Hybrid Attention Mask)设计:
实现上的三个关键点:
灵活的注意力 Mask 控制:需要能够为不同 Token 分配不同的注意力范围。 固定 KV Cache 条目:Sink Token 的 KV 缓存位置要被“钉住”,防止在长序列推理中被滑动窗口覆盖。 单次执行融合全局 + 局部计算:避免多次计算带来的性能损失。
这种非对称注意力布局是 Sink Token 在长上下文场景中既能保留推理速度,又能维持上下文连贯性的核心原因。
FlashAttention-3(FA3)的优势 原生支持前缀 + 局部混合注意力掩码(local hybrid attention masks) 可以在同一 kernel 调用中,同时处理 Sink Token 的全局访问和其他 Token 的局部注意力计算。 Hopper 架构优化(H100、L40S),带来更高吞吐量、更低延迟。
FlashAttention-2(FA2)的局限
结论: 在Hopper GPU(H100、L40S)上:vLLM + FA3 是Sink Token 性能的最优组合。 在Ampere GPU(A10、A100)上: FA3 Kernels 可能不完整或不支持,导致性能下降甚至运行失败。 Ollama 通过不运行真正的 Sink Token 逻辑(固定注意模式 + MXFP4 量化),可规避这个问题。
通过将昂贵的上下文重扫描(re-scan)转化为更具成本效益的“全局内存查找”。Sink Token 在超长上下文(≥ 32k tokens)中,显著降低 TTFT(Time To First Token) 和整体吞吐量压力。
FA3 不只是优化方案,在 vLLM 中,若想要高效运行 Sink Token,FA3 是必备条件。
在 Ampere GPU(A10 / A100)上运行 在Hopper GPU(H100 / L40S)上运行,必选 vLLM + FA3 + Sink Token,从而确保性能最大化。
gpt-oss-20b on Azure NV A10 GPU VM
在 gpt-oss 的推理逻辑中引入了 Sink Token。想要高效运行 Sink Token 需要 FA3(FA2 没有相应的内核),Hopper 架构对 FA3 的支持更好,而在 Ampere 上则存在兼容性问题。
在使用 A10时,优选 Ollama,Ollama 版本模型默认使用 MXFP4 量化,简单且节省内存。
如果不进行量化,直接使用具备 BF16 推理能力的 HF transformers,A10 的内存就不够用了。
关于这部分测试的说明:只在 Ollama 上使用了单个 A10 GPU。加载模型前如下图所示:
TTFT<1sThroughput:45~55tokens/s
importrequests, time, json
MODEL ="gpt-oss:20b"PROMPT ="Give me a 2000-word introduction to Ollama."
url ="http://localhost:11434/api/generate"payload = {"model": MODEL,"prompt": PROMPT,"stream":True}
t0 = time.time()first_token_time =Nonetoken_count =0
withrequests.post(url, json=payload, stream=True, timeout=600)asresp: forlineinresp.iter_lines(): ifnot line: continue data = json.loads(line) ifdata.get("done"): break chunk = data.get("response","") ifnot chunk: continue iffirst_token_timeisNone: first_token_time = time.time() token_count +=len(chunk.split()) # 简化统计token,可用tiktoken更精确
t1 = time.time()ttft = first_token_time - t0throughput = token_count/(t1-first_token_time)iffirst_token_time else0print(f"TTFT:{ttft:.3f}s, Tokens:{token_count}, Throughput:{throughput:.2f}tokens/s")
gpt-oss-20b on Azure H100 GPU VM
在 NC H100 上,借助 vLLM,我们能够利用 gpt-oss-20b 模型实现出色的推理性能。
vllmserveopenai/gpt-oss-20b
(oss-20b-tgi) root@h100vm:~# cat stress_test.py #!/usr/bin/envpython3#stress_test.py"""Asynchronouslystress-testalocalvLLMOpenAI-compatibleendpoint.Prerequisites:pipinstall"httpx[http2]"tqdmorjsonAuthor:2025-08-06"""importargparse,asyncio,time,statistics,osimportorjson,httpxfromtqdm.asyncioimporttqdm_asyncioastqdm#tqdm≥4.66ENDPOINT="http://127.0.0.1:8000/v1/chat/completions"HEADERS={"Content-Type":"application/json"}SYSTEM="Youareahelpfulassistant."defbuild_payload(prompt:str,max_tokens:int=128,temp:float=0.0):return{"model":"openai/gpt-oss-20b",#任意字符串也行,只要跟serve时一致"messages":[{"role":"system","content":SYSTEM},{"role":"user","content":prompt},],"temperature":temp,"max_tokens":max_tokens,"stream":False#如需压TTFT可设True,但统计更复杂}asyncdefworker(client:httpx.AsyncClient,payload:dict,latencies:list,ttfts:list,tokens:list,):"""Sendasinglerequestandrecordmetrics."""t0=time.perf_counter()resp=awaitclient.post(ENDPOINT,headers=HEADERS,content=orjson.dumps(payload))t1=time.perf_counter()ifresp.status_code!=200:raiseRuntimeError(f"HTTP{resp.status_code}:{resp.text[:200]}")out=resp.json()#usage字段遵循OpenAI规范usage=out.get("usage",{})c_tok=usage.get("completion_tokens",0)ttft=out.get("ttft",0)#vLLM0.10起返回,若无自行估算ifnotttft:#估算:总时长*(prompt_tokens/全部tokens)粗略近似p_tok=usage.get("prompt_tokens",1)ttft=(p_tok/(p_tok+c_tok+1e-6))*(t1-t0)latencies.append(t1-t0)ttfts.append(ttft)tokens.append(c_tok)asyncdefrun(concurrency:int,total_requests:int,payload:dict):latencies,ttfts,tokens=[],[],[]limits=httpx.Limits(max_connections=concurrency)timeout=httpx.Timeout(60.0)#适当加大asyncwithhttpx.AsyncClient(limits=limits,timeout=timeout,http2=True)asclient:sem=asyncio.Semaphore(concurrency)asyncdef_task(_):asyncwithsem:awaitworker(client,payload,latencies,ttfts,tokens)awaittqdm.gather(*[_task(i)foriinrange(total_requests)])returnlatencies,ttfts,tokensdefmain():ap=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)ap.add_argument("--concurrency","-c",type=int,default=64,help="numberofconcurrentrequests")ap.add_argument("--requests","-n",type=int,default=1024,help="totalrequeststosend")ap.add_argument("--prompt",type=str,default="Explainquantummechanicsinonesentence.",help="userprompt")ap.add_argument("--max-tokens",type=int,default=128)args=ap.parse_args()payload=build_payload(args.prompt,max_tokens=args.max_tokens)print(f"Startstresstest:{args.requests}requests|"f"concurrency={args.concurrency}|max_tokens={args.max_tokens}")st=time.perf_counter()lat,ttft,toks=asyncio.run(run(args.concurrency,args.requests,payload))et=time.perf_counter()total_time=et-st#──Stats────────────────────────────────────────────────────────────────defpct(lst,p):returnstatistics.quantiles(lst,n=100)[p-1]print("\nRESULTS")print(f"Totalwall-clocktime:{total_time:8.2f}s")print(f"Requests/second:{args.requests/total_time:8.1f}req/s")print(f"Tokens/second:{sum(toks)/total_time:8.1f}tok/s")forname,arrin[("Latency(s)",lat),("TTFT(s)",ttft)]:print(f"{name:<15}p50={statistics.median(arr):.3f}"f"p90={pct(arr,90):.3f}p99={pct(arr,99):.3f}")print("\nDone.")if__name__=="__main__":main()(oss-20b-tgi)root@h100vm:~#(oss-20b-tgi)root@h100vm:~#(oss-20b-tgi)root@h100vm:~#(oss-20b-tgi)root@h100vm:~#(oss-20b-tgi)root@h100vm:~#catstress_test.py#!/usr/bin/envpython3#stress_test.py"""Asynchronouslystress-testalocalvLLMOpenAI-compatibleendpoint.Prerequisites:pipinstall"httpx[http2]"tqdmorjsonAuthor:2025-08-06"""importargparse,asyncio,time,statistics,osimportorjson,httpxfromtqdm.asyncioimporttqdm_asyncioastqdm#tqdm≥4.66ENDPOINT="http://127.0.0.1:8000/v1/chat/completions"HEADERS={"Content-Type":"application/json"}SYSTEM="Youareahelpfulassistant."defbuild_payload(prompt:str,max_tokens:int=128,temp:float=0.0):return{"model":"openai/gpt-oss-20b",#任意字符串也行,只要跟serve时一致"messages":[{"role":"system","content":SYSTEM},{"role":"user","content":prompt},],"temperature":temp,"max_tokens":max_tokens,"stream":False#如需压TTFT可设True,但统计更复杂}asyncdefworker(client:httpx.AsyncClient,payload:dict,latencies:list,ttfts:list,tokens:list,):"""Sendasinglerequestandrecordmetrics."""t0=time.perf_counter()resp=awaitclient.post(ENDPOINT,headers=HEADERS,content=orjson.dumps(payload))t1=time.perf_counter()ifresp.status_code!=200:raiseRuntimeError(f"HTTP{resp.status_code}:{resp.text[:200]}")out=resp.json()#usage字段遵循OpenAI规范usage=out.get("usage",{})c_tok=usage.get("completion_tokens",0)ttft=out.get("ttft",0)#vLLM0.10起返回,若无自行估算ifnotttft:#估算:总时长*(prompt_tokens/全部tokens)粗略近似p_tok=usage.get("prompt_tokens",1)ttft=(p_tok/(p_tok+c_tok+1e-6))*(t1-t0)latencies.append(t1-t0)ttfts.append(ttft)tokens.append(c_tok)asyncdefrun(concurrency:int,total_requests:int,payload:dict):latencies,ttfts,tokens=[],[],[]limits=httpx.Limits(max_connections=concurrency)timeout=httpx.Timeout(60.0)#适当加大asyncwithhttpx.AsyncClient(limits=limits,timeout=timeout,http2=True)asclient:sem=asyncio.Semaphore(concurrency)asyncdef_task(_):asyncwithsem:awaitworker(client,payload,latencies,ttfts,tokens)awaittqdm.gather(*[_task(i)foriinrange(total_requests)])returnlatencies,ttfts,tokensdefmain():ap=argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)ap.add_argument("--concurrency","-c",type=int,default=64,help="numberofconcurrentrequests")ap.add_argument("--requests","-n",type=int,default=1024,help="totalrequeststosend")ap.add_argument("--prompt",type=str,default="Explainquantummechanicsinonesentence.",help="userprompt")ap.add_argument("--max-tokens",type=int,default=128)args=ap.parse_args()payload=build_payload(args.prompt,max_tokens=args.max_tokens)print(f"Startstresstest:{args.requests}requests|"f"concurrency={args.concurrency}|max_tokens={args.max_tokens}")st=time.perf_counter()lat,ttft,toks=asyncio.run(run(args.concurrency,args.requests,payload))et=time.perf_counter()total_time=et-st#──Stats────────────────────────────────────────────────────────────────defpct(lst,p):returnstatistics.quantiles(lst,n=100)[p-1]print("\nRESULTS")print(f"Totalwall-clocktime:{total_time:8.2f}s")print(f"Requests/second:{args.requests/total_time:8.1f}req/s")print(f"Tokens/second:{sum(toks)/total_time:8.1f}tok/s")forname,arrin[("Latency(s)",lat),("TTFT(s)",ttft)]:print(f"{name:<15}p50={statistics.median(arr):.3f}"f"p90={pct(arr,90):.3f}p99={pct(arr,99):.3f}")print("\nDone.")if__name__=="__main__":main()
x(oss-20b-tgi)root@h100vm:~#pythonstress_test.py--concurrency256--requests2000--prompt"Explainquantummechanicsinoneparagraph."--max-tokens256
(oss-20b-tgi)root@h100vm:~#pythonstress_test.py--concurrency256--requests2000--prompt"Explainquantummechanicsinoneparagraph."--max-tokens256 Startstress test:2000requests | concurrency=256| max_tokens=256100%|████████████████████████████████████████████████████████████████████████████████████████████████████████|2000/2000[01:06<00:00, 29.92it/s]
RESULTSTotalwall-clock time : 66.89 sRequests/ second : 29.9 req/sTokens / second : 7645.2 tok/sLatency(s) p50=8.835p90=11.235 p99=14.755TTFT(s) p50=2.271p90=2.874 p99=3.775
Done.
gpt-oss-120b on Azure H100 GPU VM
(gpt-oss)root@h100vm:~#vllmserveopenai/gpt-oss-120b
使用stress_test.py,只将模型从openai/gpt-oss-20b 改为openai/gpt-oss-120b。
(gpt-oss)root@h100vm:~#pythonstress_test-120b.py--concurrency256--requests2000--prompt"Explainquantummechanicsinoneparagraph."--max-tokens128
RESULTSTotalwall-clock time : 60.73 sRequests/ second : 32.9 req/sTokens / second : 4215.6 tok/sLatency(s) p50=8.254p90=10.479 p99=11.782TTFT(s) p50=3.363p90=4.269 p99=4.800
Done.
(gpt-oss)root@h100vm:~#pythonrun_local_llm.py" leasewritemeaPythonprogramthatcanrundirectlyintheterminal.ThisprogramshouldbeaTetrisgamewithacolorfulinterface,andallowtheplayertocontrolthedirectionoftheblocks,gamescreenshouldhasaclearborder,runwithoutanyerror."
#!/usr/bin/env python3"""简易命令行调用本地 vLLM (OpenAI 兼容) 的脚本用法: python run_llm.py "你的 prompt ..."可选: -m/--model 指定模型名称 (默认: 自动探测) -u/--url 指定服务地址 (默认: http://127.0.0.1:8000) -v/--verbose 显示完整 JSON 响应"""
importargparseimportsysimportrequestsfromopenaiimportOpenAIfromopenai.types.chatimportChatCompletion
deflist_models(base_url:str): """调用 /v1/models 获取当前加载的模型列表""" try: resp = requests.get(f"{base_url.rstrip('/')}/models", timeout=3) resp.raise_for_status() data = resp.json() return[m["id"]formindata.get("data", [])] exceptException: return[]
defmain() ->None: parser = argparse.ArgumentParser(description="Run prompt on local vLLM server") parser.add_argument("prompt", nargs="+",help="提示词") parser.add_argument("-m","--model",help="模型名称 (默认: 自动探测)") parser.add_argument("-u","--url", default="http://127.0.0.1:8000", help="服务器地址(不带 /v1),默认 http://127.0.0.1:8000") parser.add_argument("-v","--verbose", action="store_true", help="打印完整 JSON")
args = parser.parse_args() base_url = args.url.rstrip("/") +"/v1"
# 如果用户没指定模型,就到 /v1/models 去探测 model_name = args.model ifmodel_nameisNone: models = list_models(base_url) ifnot models: print("❌ 无法从 /v1/models 获取模型列表,请检查 vLLM 是否在运行。", file=sys.stderr) sys.exit(1) iflen(models) >1: print("⚠️ 服务器上有多个模型,请用 -m 指定;当前可用:",", ".join(models), file=sys.stderr) sys.exit(1) model_name = models[0]
prompt_text =" ".join(args.prompt)
client = OpenAI(base_url=base_url, api_key="EMPTY")
try: resp: ChatCompletion = client.chat.completions.create( model=model_name, messages=[{"role":"user","content": prompt_text}], temperature=0.7 ) exceptExceptionase: print("❌ 调用失败:", e, file=sys.stderr) sys.exit(1)
# 输出 ifargs.verbose: print("=== 完整 JSON ===") print(resp.model_dump_json(indent=2, ensure_ascii=False)) print("\n=== 模型回答 ===")
print(resp.choices[0].message.content.strip())
if__name__ =="__main__": main()
gpt-oss-120b on Azure AI foundry
gpt-oss-120b 现已在 Azure AI Foundry(国际版)上线,并且可以非常方便的实现一键部署。
|