返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

百灵大模型 Ling 和 Ring 系列首发支持 SGLang-JAX 推理引擎

[复制链接]
链载Ai 显示全部楼层 发表于 昨天 22:33 |阅读模式 打印 上一主题 下一主题


10 月 19 日,SGLang 团队发布文章,宣布推出全新的开源推理引擎 SGLang-JAX —— 一个完全基于 JAX 和 XLA 构建的最先进的开源推理引擎,实现了快速、原生的 TPU 推理。

SGLang-JAX 项目首发支持了百灵非思考模型 Ling 和思考模型 Ring,感谢社区伙伴对百灵大模型的认可与支持,期待后续更多更深入的社区合作。

【百灵模型系列官方链接🤗

  • https://huggingface.co/collections/inclusionAI/ling-v2

  • https://huggingface.co/collections/inclusionAI/ring-v2


SGLang-JAX 技术博客原文(https://lmsys.org/blog/2025-10-29-sglang-jax/)翻译如下:


SGLang-JAX:一种用于原生 TPU 推理的开源解决方案

作者:SGLang-JAX 团队,2025 年 10 月 29 日

我们非常高兴地介绍 SGLang-JAX —— 一个完全基于 JAX 和 XLA 构建的最先进开源推理引擎。
它利用了 SGLang 的高性能服务端架构,并使用 JAX 来编译模型的前向计算过程。
通过结合 SGLang 与 JAX,本项目在保持 连续批处理(continuous batching)、前缀缓存(prefix caching)、张量与专家并行(tensor & expert parallelism)、推测解码(speculative decoding)、内核融合(kernel fusion)等高级特性支持的同时,实现了快速、原生的 TPU 推理。

基准测试显示,SGLang-JAX 的性能与其他 TPU 推理方案相当,甚至在某些情况下更胜一筹。
源码可在此获取:
👉 https://github.com/sgl-project/sglang-jax


为何选择 JAX 后端?

虽然 SGLang 最初基于 PyTorch 构建,但社区一直强烈希望支持 JAX。
我们选择构建 JAX 后端的主要原因如下:

  • JAX 为 TPU 而生
    为了在性能上做到极致无妥协,Jax 是最合适的选择。随着 Google 扩大 TPU 的公共访问,我们预计 JAX + TPU 将获得显著增长,带来高性价比的推理能力。

  • 顶级 AI 实验室已广泛采用 JAX
    Google DeepMind、xAI、Anthropic 和 Apple 等实验室使用同一框架进行训练与推理,可减少维护成本并避免两阶段间的偏移。

  • JAX + XLA 是经过验证的编译驱动栈
    它在 TPU 上性能卓越,同时在多种自研 AI 芯片(TPU 类架构)上表现优异。


架构

下图展示了 SGLang-JAX 的体系结构。整个堆栈完全基于 JAX,代码简洁且依赖最小化。

在输入端,它通过 OpenAI 兼容 API 接收请求,利用 SGLang 的高效 RadixCache 进行前缀缓存,并采用 重叠调度器(overlap scheduler) 实现低开销批处理。
调度器会为不同批次大小预编译 JAX 计算图。

在模型端,我们使用 Flax 实现模型,并通过 shard_map 支持多种并行策略。
两个核心算子 —— 注意力(attention)和 MoE —— 使用自定义 Pallas 内核 实现。

SGLang-JAX 架构示意图


关键优化

1. 集成 Ragged Paged Attention v3

我们将 RPA v3 集成至系统,并扩展以支持 SGLang 的特性:

  • 根据不同场景优化 kernel grid block 配置,以获得更佳性能。

  • 实现与 RadixCache 的兼容。

  • 为支持 EAGLE 推测解码(speculative decoding),我们为 RPA v3 增加了用于验证阶段的自定义 mask。

2. 减少调度开销

在前向计算过程中,CPU 与 TPU 的串行操作会影响性能。

然而,不同设备上的操作可以解耦,例如:在 TPU 启动计算的同时,CPU 可准备下一个批次。

为此,我们的调度器通过事件循环重叠 CPU 与 TPU 的任务执行:

  • 调度器使用结果队列与线程事件,实现 CPU 与 TPU 的流水线化。

  • 当 TPU 处理批次 N 时,CPU 准备批次 N+1。

  • 根据性能分析结果优化操作顺序,实现最大化重叠。

例如在 Qwen/Qwen3-32B 上,我们将 预填充到解码的时间间隙 从约 12ms → 38μs,以及 7ms → 24μs。

更多细节可参考我们此前的博客(https://lmsys.org/blog/2024-12-04-sglang-v0-4/)。

使用重叠调度器:批次间间隙极小

不使用调度器:批次间存在显著 CPU 开销

3. MoE 内核优化

MoE 层目前支持两种实现策略:EPMoE 与 FusedMoE。

  • 在 EPMoE 中,我们集成了 Megablox GMM 运算符,替代了原先基于 jax.ragged_dot 的实现。该算子专为 MoE 设计,能高效处理不同大小的专家组(group_sizes),避免冗余计算与非连续内存访问。在典型配置下,性能提升可达 3–4× e2e ITL 加速。结合高效的 token 排列(permute/unpermute)、跨设备专家并行通信(ragged_all_to_all) 与 自适应分块策略,EPMoE 能显著提升吞吐量。

  • FusedMoE 将所有专家计算融合为稠密的 einsum 操作,无需跨设备通信,适合专家数量少但单体较大的模型(如 <64 个专家)。它同时也是调试与正确性验证的轻量后备方案。

4. 推测解码(Speculative Decoding)

SGLang-JAX 实现了基于 EAGLE 的推测解码,也称为 多 Token 预测(MTP)。

该技术通过轻量级 draft head 一次预测多个 Token,并在验证阶段通过完整模型并行确认,从而加速生成。

为支持树状的 MTP-Verify,SGLang-Jax 在 RPA v3 之上增加了 非因果 mask 支持,使其可在验证阶段并行解码树状、非因果 Token。

目前支持 Eagle2 和 Eagle3,未来将继续优化内核并扩展不同注意力后端的支持。


TPU 性能

经过上述优化后,SGLang-JAX 的性能与其他 TPU 推理方案相当甚至更优。
在 TPU 与 GPU 的对比中,SGLang-JAX 也展现出强劲的竞争力。

完整基准结果与使用说明见:
👉 https://github.com/sgl-project/sglang-jax/issues/297


使用方法


安装 SGLang-JAX 与启动服务

安装:

    #withuvuvvenv--python3.12&&source.venv/bin/activateuvpipinstallsglang-jax#fromsourcegitclonehttps://github.com/sgl-project/sglang-jaxcdsglang-jaxuvvenv--python3.12&&source.venv/bin/activateuvpipinstall-epython/

    启动服务:

      MODEL_NAME="Qwen/Qwen3-8B" # or"Qwen/Qwen3-32B"
      jax_COMPILATION_CACHE_DIR=/tmp/jit_cache\uv run python -u -m sgl_jax.launch_server\--model-path${MODEL_NAME}\--trust-remote-code\--tp-size=4\--device=tpu\--mem-fraction-static=0.8\--chunked-prefill-size=2048\--download-dir=/tmp\--dtype=bfloat16\--max-running-requests256\--page-size=128


      通过 GCP 控制台使用 TPU

      可在菜单Compute Engine → Create TPU中创建 TPU。

      注意不同区域支持的 TPU 版本不同,请设置 TPU 软件版本为v2-alpha-tpuv6e。

      前往Settings → Metadata → SSH Keys,添加公钥。创建完成后,可通过External IP + 用户名登录 TPU。

      更多说明见:👉Google Cloud TPU 设置指南https://docs.cloud.google.com/tpu/docs/setup-gcp-account


      通过 Skypilot 使用 TPU

      建议在日常开发中使用Skypilot。

      可通过以下步骤快速搭建环境并运行测试:

      安装 Skypilot(GCP 环境)https://docs.skypilot.co/en/latest/getting-started/installation.html#gcp:

      👉安装指南

      然后运行sgl-jax.sky.yaml:

        skylaunchsgl-jax.sky.yaml--cluster=sgl-jax-skypilot-v6e-4--infra=gcp-i30--down-y--use-spot

        该命令会在各地区自动选择最低价的 TPU Spot 实例,闲置 30 分钟后自动关闭,并自动安装 sglang-jax 环境。设置完成后,可直接通过ssh cluster_name登录,无需记录外部 IP。


        路线图

        我们与 Google Cloud 团队及多方合作伙伴正在推进以下计划:


        模型支持与优化

        • 优化 Grok2、Ling/Ring、DeepSeek V3、GPT-OSS

        • 支持 MiMo-Audio、Wan 2.1、Qwen3 VL

        TPU 优化内核

        • 量化内核

        • 通信与计算重叠内核

        • MLA 内核

        强化 RL 训练与权重同步

        • 与 tunix 集成

        • Pathways 与多主机支持

        高级服务能力

        • 预填充-解码分离(Prefill-decode disaggregation)

        • 分层 KV 缓存(Hierarchical KV cache)

        • 多 LoRA 批处理(Multi-LoRA batching)

        回复

        使用道具 举报

        您需要登录后才可以回帖 登录 | 立即注册

        本版积分规则

        链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
        • 官方手机版

        • 微信公众号

        • 商务合作

        • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
        • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ