过去我用 Agent 复刻了很多在我技术栈能够 cover 的工作和想法,可以发现 Agent 的效率很高。但是保持自身竞争力的方式,不是 Fork,而是创新工作。
这里的 Fork 是指:用 Agent 更快地完成你本来就会做的事——写一个你熟悉的 CRUD 服务、移植一个你做过的驱动、重构一段你理解的代码。
Agent 在这类任务上的效率提升是显著的,但它本质上是提速,不是拓展。当所有人都能用 Agent 以十倍速完成同类工作时,这个效率优势就不再是壁垒。
创新工作是指:借助 Agent 进入你原本不具备完整能力的领域,产出原本不存在的东西。
笔者对 QEMU 和 RISC-V 比较熟悉,但对 Triton 编译器内部的 MLIR Pass 体系、CPU 后端的 IR 降级路径并不了解。
单靠自己,从零实现一个 Triton 的 RISC-V CPU 后端需要数周的学习和试错。而 Agent 可以在这个过程中充当领域桥梁——将笔者的 RISC-V 架构知识与 Triton 的编译器框架对接,产出一个之前不存在的后端实现。
所以我尝试利用 Agent 协助我给 Triton 支持一个新的后端 qemu-riscv64。这篇文章会分享:
通过启发式提问,让 Agent 协助人快速学习新领域;
利用 Git,让 Agent 的每个 feature 可编译、可测试、可二分查找;
探讨如何进行工程化落地和交付。
01
什么是 Triton
Triton 是由 OpenAI 开发的一种用于编写高性能并行计算 kernel 的 DSL,基于 Python 语法,它与 CUDA 的关系类似于“更高层次的 DSL 前端”。
相比 CUDA,Triton 的核心思想是关注“怎么对一个 block 的数据做运算”,这个是程序员使用 Triton 编写 kernel 的第一要务。至于具体如何调度到硬件线程、如何利用向量指令、如何管理缓存,则全部由编译器决定。
这么做是符合直觉的,可以大幅度降低算子开发的心智负担。
如果你的工作流切换到 Agent 上面了,可以按照下面的 prompt,让 Agent 辅助你理解什么是 Triton:
“请帮我调研 Triton 的官方手册,总结一下 Triton 的核心机制、编程示例、与其他 DSL 的显著区别,形成文档,并在文末给出相关的参考链接,允许你在这个过程中向我提问,一次一个问题。”
这段提示词有两个关键点:一手信息来源要忠实于官方资料;通过 Agent 向人不断地提问,来生成符合个体技术背景差异的回答。
GPU 后端是 Triton 的原生目标,但 RISC-V CPU 后端的存在有以下价值:
开发调试:在没有 GPU 的机器上开发和调试 Triton kernel;
通用部署:在边缘设备、嵌入式系统、RISC-V 服务器上运行相同的 kernel 代码;
性能基线:为 GPU kernel 提供 CPU 参考实现,用于正确性验证;
新架构验证:RISC-V 生态正在快速发展,CPU 后端使 Triton 成为 RISC-V 向量计算的编译前端。
目前官方已经给出了 x86 和 aarch64 参考实现,叫做
triton-cpu
,
这是一个 Triton 的下游仓库,在支持了 CPU 后端的同时,定期 rebase 上游基线。
如果想要快速了解 triton-cpu 这个项目,可以这样问你的 Agent:
“调研 triton-cpu 的 GitHub 仓库,总结 CPU 后端的当前进展,和 triton 上游的 commit 差异。”
现在我们来看一个简单的 vector add kernel,来理解 Triton 的编程模型。
import triton
import triton.language as tl
@triton.jit # Kernel 装饰器
def vector_add_kernel(x_ptr, y_ptr, out_ptr, n_elements,
BLOCK_SIZE: tl.constexpr): # program_id(0) 返回当前 program 实例在第 0 维的索引
pid = tl.program_id(0) # 计算本 block 处理的元素偏移
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # 越界保护 mask
mask = offsets < n_elements # 带 mask 的向量加载
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
# 向量存储
tl.store(out_ptr + offsets, x + y, mask=mask)
关键特征在最后几行,tl.load/store 描述了数据块的加载和存储方式,从这里可以看出来,程序员关注点主要在数据块的切分和使用上。
再进一步,我们需要了解 Triton DSL 是怎么一步步编译到 RISC-V 指令的,这有助于帮助我们理解 Triton 的定位,如果你了解 LLVM,那么它会是一个很好的类比对象。
到这一步就比较清晰了,Triton 作为一个领域特定的编译前端,利用一系列特化 Pass 进行多级 IR 降级,通过 MLIR-to-LLVM 的标准路径落地到 LLVM IR,最终由 LLVM 完成指令选择、寄存器分配、机器代码生成。
区别在于,Triton IR 原生就是“向量化、块级并行”的,而 LLVM IR 通常从标量语义开始。
你可以这样提问 Agent:
“Triton CPU 后端和 GPU 后端在 IR 降级路径上有哪些差异?他们的最终产物在调用和执行上有没有区别?”
02
Triton RISC-V CPU 后端实现
支持 RISC-V 后端可以使 triton-cpu 能够将 Triton kernel 编译为 RISC-V 机器码。选择 qemu-riscv64 作为后端执行对象,可以提供通用的 RISC-V 参考实现,能够在 PC 机器上通过二进制翻译技术运行 RISC-V 机器码。
下面给出架构图:
被
triton.jit
装饰的部分是具体的 kernel 函数,它会被转化成 TTIR(Triton IR),这里是硬件无关的高层 IR;
接下来会被转化为 TTCIR(Triton CPU IR),在这里引入 CPU 概念;
更进一步地,转化为 TTTCIR(Target Triton CPU IR),进行目标架构特化,图中展示的 ConvertDotToRVV 是由 Agent 编写的用于将标量操作进行向量化的 Pass。
从 TTCIR 往后开始进入 LLVM 的世界,首先是转化成 LLVM MLIR(LLVM Dialect in MLIR);
然后是转化为 LLVM IR,接着是输出为 ASM 汇编文本,这里对应的是 RISC-V 汇编;
最后编译为 .so 共享库二进制,如果需要运行在 qemu-riscv64 上面,这里可以利用 gcc 或者 clang 进行交叉编译;接下来就是运行时加载了。
我们再画一张图,梳理一下:
以上是通过和 Agent 讨论制定的 Spec,这里推荐使用 Claude + humanize 帮你生成更高质量的 Plan。你可以按照下面的 prompt 进行操作:
“ /humanize:gen-idea 帮我给 triton-cpu 增加一个新的后端 qemu-riscv64,可以在 x86 的机器上运行 RISC-V 的机器码。在探索的过程,可以向我大量提问。”
之后 Claude 会按照提示词,调用多个 subagent 展开并行的 explore,并返回一个 draft。
我们在详细阅读这个 draft 以后,需要和 Agent 进行多轮探讨,明确方向上没有问题。接下来,我们继续生成 Plan:
“ /humanize:gen-plan 根据 idea 进一步生成 plan。”
这一步会将 idea 的 draft 拆解为具体的 AC 和 Task,并确定里程碑和验收目标。
我们在审查 Plan 没有问题以后,可以开启 rlcr-loop,进入实施阶段:
“ /humanize:start-rlcr-loop 开始执行这个 plan。”
为了确保我们充分理解 Plan 的内容,humanize 会在启动 rlcr-loop 之前,根据 Plan 的内容向我们随机提问,如果回答错误,会建议我仔细阅读 Plan 以后再启动。
如果没有问题,Agent 将努力完成 Plan 中的每一个目标,并且在每轮 loop 结束前,调用 Codex(GPT-5.5)进行严格的审查。
在启动 rlcr-loop 之前,还需要和 Agent 对齐 Git 工作流。
Agent 生成的代码如果不落到可追溯的 commit 上,事后就无法 review、无法 bisect、无法复现问题。笔者在 Plan 和对话中使用了以下 prompt 来约束 Agent 的 Git 行为:
“在开发分支上,每轮 loop 结束时生成一个 commit,commit message 以 Round N 开头,描述本轮完成的 AC 和关键修改。所有 loop 结束后,按 feature 维度将多轮 commit squash 为独立的 feature commit 合入 main,确保 main 上的每个 commit 都可以独立编译和测试。”
实际产出的 Git 历史如下:
# 开发分支 triton-cpu-riscv-qemu(保留完整迭代过程)
dc2586a9a [CPU] 递增式模型片段测试: softmax + layernorm + MLP 端到端 qemu 验证
8048af6b9 [CPU] Round 1: 完善模型片段测试 - 三 .so MLP 组合 + OMP 精确一致性
4f4a3debe [CPU] Round 2: softmax -inf 掩码修复 + 负向测试 + C 驱动加固
a1b41196c [CPU] Round 3: MLP 维度不匹配检测 + qemu 直接测试 + softmax 参考修复...
ae838915d [CPU] Round 19: 修复 test_math.py 向量库期望以匹配后端行为
cb38cd3d4 [CPU] Finalize: 提取重复逻辑为辅助方法 + 移动延迟导入# main 分支(squash 后的 feature commit,每个可独立编译测试)
700f8cd8a [CPU] Add RISC-V 64-bit cross-compilation infrastructure
f522fa953 [CPU] Add ConvertDotToRVV pass and SLEEF RVVM1 math dispatch
cbb9a02ee [CPU] Add RISC-V build targets and SLEEF cross-compilation
75920596b [CPU] Add comprehensive RISC-V 64-bit test suite
3bd5f5e10 [CPU] Add RISC-V 64-bit backend documentation
这个两层结构的好处是:开发分支保留了 Agent 每轮 loop 的完整修改记录,可以回溯任何一次 Codex 审查后的修正细节;
main 分支上每个 commit 对应一个完整的 feature 交付,可以用
git bisect
定位回归,也可以 cherry-pick 单个 feature 到其他分支。
为了完成对 RISC-V 后端的适配,笔者一共拟定了四份 Plan,分别是功能实现、RVV/BF16/FP16 测试、模型片段测试、文档编写。
四个 Plan 平均耗时在 3 个小时左右,loop 数平均在 9 轮左右,最终交付的代码质量相当高。
03
关键 Pass 实现拆解
我们从每个阶段的 Pass 入手,对于 triton-cpu 已有的部分,我们一笔带过,主要讲解 Agent 在哪些环节做了关键修改。
阶段 1:
make_ttir()
— 硬件无关优化
passes.common.add_inliner(pm) # 函数内联
passes.ttir.add_combine(pm) # Triton IR 组合优化
passes.common.add_canonicalizer(pm) # MLIR 标准规范化
passes.ttir.add_reorder_broadcast(pm) # 广播操作重排
passes.common.add_cse(pm) # 公共子表达式消除
passes.common.add_licm(pm) # 循环不变代码外提
passes.common.add_symbol_dce(pm) # 死符号消除
这个阶段与 GPU 后端共享,类似于 LLVM 中
-O2
通用的中端优化。
阶段 2:
make_ttcir()
— TTIR 到 CPU IR 转换
cpu.passes.ttcpuir.add_scalarize(pm, True) # 标量化部分操作
cpu.passes.ttcpuir.add_convert_memory_ops(pm, True) # tl.load/tl.store → memref 操作
cpu.passes.ttcpuir.add_convert_ptr_ops(pm) # 指针运算转换
cpu.passes.ttcpuir.add_convert_elementwise_ops(pm) # 逐元素操作(加减乘除)
cpu.passes.ttcpuir.add_convert_elem_manip_ops(pm) # 元素操纵(reshape, broadcast)
cpu.passes.ttcpuir.add_convert_dot_op(pm) # tl.dot → cpu::DotOp
cpu.passes.ttcpuir.add_convert_histogram_op(pm) # 直方图操作
cpu.passes.ttcpuir.add_convert_reduction_op(pm, True, False) # tl.sum/tl.max → 规约
cpu.passes.ttcpuir.add_convert_scan_op(pm) # 前缀扫描操作
cpu.passes.ttcpuir.add_convert_cf_ops(pm) # 控制流转换
cpu.passes.ttcpuir.add_convert_atomic_ops(pm) # 原子操作
cpu.passes.ttcpuir.add_convert_debug_ops(pm) # 调试操作
passes.common.add_cse(pm) # CSE 清理
passes.common.add_symbol_dce(pm)
passes.common.add_canonicalizer(pm)
这是 Triton 特有操作到 CPU 语义的映射。例如
tl.load
的 mask 语义在这里被展开为条件向量加载。
Triton 的内存操作天然支持 mask(越界保护):
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
在 CPU 后端的
makettcir()阶段,addconvertmemoryops
pass 将其转换为:
// 带 mask 时 → LLVM masked load intrinsic
%result = call <N x float> @llvm.masked.load.vNf32.p0(
ptr %addr,
i32 alignment,
<N x i1> %mask,
<N x float> %passthru // other=0.0 对应 zeroinitializer
)
不带 mask 时直接变成普通的 LLVM
load
指令。
在 RISC-V 后端,
llvm.masked.load
最终由 LLVM 的 RVV 后端降级为
vle32.v
+
vmerge
或基于 mask 的条件加载指令。
阶段 3:
make_tttcir()
— 目标架构特化
这是 RISC-V 特化的核心阶段,依据架构和扩展特征选择不同的 pass:
cpu.passes.ttcpuir.add_triton_cpu_canonicalizer(pm) # CPU IR 规范化
cpu.passes.ttcpuir.add_optimize_masks(pm) # mask 优化
passes.common.add_canonicalizer(pm)
# --- 架构特化的 DotOp 转换(按优先级排列)---
# Intel AMX(x86, 有 amx-tile 特征时)
# cpu.passes.ttcpuir.add_convert_dot_to_amx(pm, ...)
# x86 AVX512 FMA
# cpu.passes.ttcpuir.add_convert_dot_to_fma(pm)
# *** RISC-V RVV(riscv64, 有 +v 特征时)***
if arch == "riscv64" and "v" in features:
cpu.passes.ttcpuir.add_convert_dot_to_rvv(pm)
# 通用回退(所有架构都会注册)
cpu.passes.ttcpuir.add_convert_dot_generic(pm)
# --- 数据类型转换策略 ---
# BF16/FP16 条件编译(见 4.3 节)
cpu.passes.ttcpuir.add_convert_unsupported_ops(pm, ...)
cpu.passes.ttcpuir.add_decompose_fp_conversions(pm, ...)
在
make_tttcir()
阶段,还会处理 Triton 的
tl.dot
语义,这个是矩阵乘法的核心算子。
主要应用 ConvertDotToRVV 这个 Pass,将一个
cpu::DotOp
(形状是
[M, K] x [K, N] -> [M, N]
)展开为行级向量 FMA 操作,LLVM 后端将其映射为 RVV 的
vfmacc.vf
指令。
该 Pass 位于:
third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToRVV.cpp
。
Agent 在实现这个 Pass 的时候,遵循笔者的指令,尽可能复用已有实现,
因此 ConvertDotToRVV 与 ConvertDotToFMA(x86)共享相同的高层策略:
先
遍历
DotOp
操作,通过候选检查函数找到候选,然后分析内存缓冲区:
findInputBuffer()
处理 LHS/RHS;之后检查累加器循环携带:
isLoopCarriedAcc()
最后
逐行降级为 FMA:提取 LHS 标量,广播为 N 宽向量,与 RHS 行进行乘加累加。
关键差异在于:
FMA 针对定宽 x86 SIMD(SSE/AVX),LLVM 对 x86(直接指令选择)和 RISC-V(vsetvli + RVV 指令选择)使用不同的后端策略。
ConvertDotToAMX(x86)则复杂得多,使用硬件 tile 寄存器、显式 tile 大小约束和 VNNI 编码。RVV 没有 tile 寄存器,向量 FMA 方法更简单。
具体思路如下:
匹配这个 Pass 的条件(isRvvCandidate)要求,输入必须是 2D 向量,类型暂时只支持了 f32 和 f64,不要求特定的 shape 大小。
在分块形状上,任何 rank-2 的 2D DotOp 都可以,该 Pass 不施加分块大小约束,LLVM 后端通过
vsetvli
进行任意向量宽度的处理,充分利用 RVV 的特点。
我们来看一下转换策略。
给定
C[M,N] += A[M,K] * B[K,N]
,生成如下 MLIR 操作序列:
// 1. 提取累加器的每一行
for m in 0..M:
accRow[m] = vector.extract C, [m] // <N x f32>
// 2. 外积累加
for m in 0..M:
for k in 0..K:
scalar = vector.extract A, [m, k] // f32
splat = vector.broadcast scalar → <N x f32>
row_b = vector.extract B, [k] // <N x f32> accRow[m] = vector.fma(splat, row_b, accRow[m])// 3. 写回结果for m in 0..M: result = vector.insert accRow[m], result, [m]
这个方案的好处是,
每次 FMA 操作都是行级向量操作
,LLVM 后端可以直接映射到 RVV 的
vfmacc.vf
指令(标量-向量 FMA),无需额外的向量 shuffle。
利用
环回累加器优化
:
isLoopCarriedAcc()
检测累加器是否在循环中被反复更新(如 matmul 的 K 维循环),如果是,累加器值可以保持在寄存器中,避免重复 load/store。
整体来看,这个 Pass 并没有让人很惊艳的地方,属于比较常规的处理策略。
阶段 4:
make_llir()
— 到 LLVM IR 的降级
# TritonCPU IR → LLVM Dialect (MLIR)
cpu.passes.ttcpuir.add_lower_vector_multi_dim(pm) # 多维向量展平
cpu.passes.ttcpuir.add_expand_strided_metadata(pm) # stride 元数据展开
cpu.passes.ttcpuir.add_vector_to_scf(pm, True, 1, False) # vector → scf 循环
cpu.passes.ttcpuir.add_lower_affine(pm) # affine → 标准
passes.convert.add_scf_to_cf(pm) # scf → cf(控制流)
passes.convert.add_index_to_llvmir(pm) # index → LLVM
# *** Triton 特有操作的 LLVM 转换 ***
cpu.passes.ttcpuir.add_func_op_to_llvmir(pm) # tt::FuncOp → LLVM::FuncOp #(这里追加 6 个 program id 参数)cpu.passes.ttcpuir.add_program_id_to_llvmir(pm) # GetProgramIdOp → 函数参数读取cpu.passes.ttcpuir.add_memory_op_to_llvmir(pm) # 内存操作到 LLVMcpu.passes.ttcpuir.add_atomic_ops_to_llvmir(pm) # 原子操作到 LLVMcpu.passes.ttcpuir.add_debug_ops_to_llvmir(pm) # 调试操作到 LLVM# *** 数学库派发 ***# 对于 riscv64+v:SLEEF RVVM1 向量化数学函数# 对于 x86 AVX512:libmvec 或 SLEEFcpu.passes.ttcpuir.add_math_to_vec_lib(pm, vec_lib, features)passes.convert.add_math_to_llvmir(pm) # math dialect → LLVM intrinsicscpu.passes.ttcpuir.add_math_to_libm(pm) # 剩余 math → libm 函数调用cpu.passes.ttcpuir.add_vector_to_llvmir(pm, fast_math, arch) # vector → LLVMcpu.passes.ttcpuir.add_memref_to_llvmir(pm) # memref → LLVMpasses.convert.add_reconcile_unrealized(pm) # 消除类型转换占位passes.convert.add_arith_to_llvmir(pm) # arith → LLVMcpu.passes.ttcpuir.add_func_to_llvmir(pm) # func → LLVMcpu.passes.ttcpuir.add_ub_to_llvmir(pm) # UB → LLVM# LLVM MLIR → LLVM IR(纯 LLVM 世界)llvm.to_module(mod, context)# 对于交叉编译:llvm.set_target(llvm_mod, target_triple, target_cpu, target_features)llvm.optimize_module(llvm_mod, OPTIMIZE_O3, target_cpu, target_features, ...)
与 GPU 后端的一个显著差异,是对
tl.programid语义的映射。
**
在 GPU 上,programid
对应硬件的 block index 寄存器;在 CPU 上,它变成了
函数参数
。
在
makellir()的 addfuncopto_llvmir
pass 中,
FuncOpConversion::amendProgramIdArgs()
向每个 kernel 函数追加 6 个参数:
// 原始函数签名:void kernel(float* x, float* y, float* out, int n)
// 追加后:void kernel(float* x, float* y, float* out, int n,
// int32_t pid0, int32_t pid1, int32_t pid2,
// uint32_t gridX, uint32_t gridY, uint32_t gridZ)
amendedInputTy.push_back(i32_ty); // pid0
amendedInputTy.push_back(i32_ty); // pid1
amendedInputTy.push_back(i32_ty); // pid2
amendedInputTy.push_back(ui32_ty); // gridX
amendedInputTy.push_back(ui32_ty); // gridY
amendedInputTy.push_back(ui32_ty); // gridZ
GetProgramIdOpConversion
将
tl.program_id(axis)
替换为对函数参数的直接读取:
// Utility.cpp
Value getProgramId(FunctionOpInterface funcOp, int axis) {
auto args = funcOp.getArguments();
auto argIdx = args.size() - 6 + axis;
return args[argIdx];
}
Value getNumPrograms(FunctionOpInterface funcOp, int axis) {
auto args = funcOp.getArguments();
auto argIdx = args.size() - 3 + axis;
return args[argIdx];
}
我们再看对数学函数的处理,这里关注
tl.exp
/
tl.sin
等语义。
数学函数的编译路径,会经过三层降级:
tl.exp(x)
→ math::ExpOp (MLIR math dialect)
→ [MathToVecLib pass]
→ 如果是 riscv64+v 且 SLEEF 可用:
func.call @Sleef_expfx_u10rvvm1(%vec) # SLEEF RVVM1 向量化版本
→ 如果是 riscv64+v 但 SLEEF 不可用:
[math_to_llvmir pass]
→ llvm.exp intrinsic(由 LLVM 后端展开为标量循环 + libm 调用)
→ 如果是 riscv64 无 +v:
[math_to_libm pass]
→ func.call @expf(%scalar) # 标量 libm
这里 Agent 选择了 SLEEF 库,为 Triton 提供向量化的数学函数实现。
主要支持了 SLEEF RVVM1 架构。RVVM1 表示 SLEEF 使用 LMUL=1 的 RVV 向量寄存器组,这是一种与向量长度无关(VLA)的 ABI。LLVM 在代码生成时根据目标硬件的 VLEN 决定实际向量长度。
具体操作是在编译时 Pass 层(MathToVecLib.cpp)中进行判断,当
isRvv == true
时,
populateSleefRvvPatterns()
注册了 27 个数学函数的 RVVM1 变体:
// 每个函数都使用 SleefNameGenerator("name", ulp, /*rvvm1=*/true)
populatePatternsForOp<math::SinOp>(patterns, gen("sin"), ...);
populatePatternsForOp<math::CosOp>(patterns, gen("cos"), ...);
populatePatternsForOp<math::ExpOp>(patterns, gen("exp"), ...);
// ... 共 27 个数学操作
SleefNameGenerator
的 RVVM1 模式:
if (useRvvm1) {
return "Sleef_" + baseName + (bitwidth == 32 ? "f" : "d") +
"x" + ulpSuffix + "rvvm1";
}
RVVM1 模式通过
updatevecsize()
中的 RVV 特征检测激活:
if (feature == "v") {
isRvv = true;
vec_size_in_bits = std::max<size_t>(vec_size_in_bits, 128);
}
阶段 5:
make_asm()
— LLVM IR 到汇编
# 交叉编译场景
llvm.translate_to_asm(src, target_triple, target_cpu, target_features, ...)
# 本地编译场景
llvm.translate_to_host_asm(src, ...)
此阶段调用 LLVM 的
TargetMachine::addPassesToEmitFile()
,走完整的 LLVM 后端流水线,比如 SelectionDAG / GlobalISel、指令选择、寄存器分配、指令调度、MC 层代码发射等。
阶段 6:
make_so()
— 汇编到共享库
# 交叉编译路径:调用 riscv64-linux-gnu-gcc
_cross_build("kernel", asm_path, tmpdir, options)
# 本地编译路径:调用宿主 gcc/clang
_build("kernel", asm_path, tmpdir, lib_dirs, include_dirs, libs, ccflags)
04
Kernel 的运行时执行模型
Triton kernel 的并行模型是每个 program 实例独立处理一个 tile,所有实例构成一个 3D grid。在 CPU 后端,这个 grid 通过 OpenMP
parallel for
映射到线程。
核心调度代码在
thirdparty/cpu/backend/driver.py的 runomp_kernels
模板中:
// 1. 展平 3D grid 为 1D
size_t N = gridX * gridY * gridZ;
auto all_grids = get_all_grids(gridX, gridY, gridZ);
// 2. 特殊情况:单实例直接调用
if (N == 1) {
(*kernel_ptr)(args..., 0, 0, 0, 1, 1, 1);
return;
}
// 3. 特殊情况:单线程顺序执行(避免 OMP 开销)
if (max_threads == 1) {
for (size_t i = 0; i < N; ++i) {
const auto [x, y, z] = all_grids[i];
(*kernel_ptr)(args..., x, y, z, gridX, gridY, gridZ);
}
return;
}
// 4. 多线程并行执行
#pragma omp parallel for schedule(static) num_threads(max_threads)
for (size_t i = 0; i < N; ++i) {
const auto [x, y, z] = all_grids[i];
(*kernel_ptr)(args..., x, y, z, gridX, gridY, gridZ);
}
Grid 展开顺序是
(z, y, x)
(z 最外层),这与 CUDA 的 grid 遍历顺序一致。每个 OpenMP 线程调用 kernel 函数时传入不同的
(x, y, z)
值,这就是
tl.program_id()
的来源。
16 核示例:
vector_add(n=16384, BLOCK_SIZE=1024)
grid = (16,) → 16 grid points
OMP_NUM_THREADS=16:
core 0: kernel(pid=0) → arr[0:1024] 每核处理 1 个 block
core 1: kernel(pid=1) → arr[1024:2048]
...
core 15: kernel(pid=15) → arr[15360:16384]
OMP_NUM_THREADS=4:
core 0: pid=0,1,2,3 (schedule=static) 每核处理 4 个 block
core 1: pid=4,5,6,7
core 2: pid=8,9,10,11
core 3: pid=12,13,14,15
交叉编译时默认加
-fopenmp
(
compiler.py:crossbuild()
),生成的
.so
已包含 OpenMP 并行代码。可通过
TRITONDISABLEOPENMP=1
禁用。
PS:如果你对并行编程感兴趣,强烈推荐周洲仪博士参与翻译的《深入理解并行编程》第 2 版。在 Agent 时代,可以静下心来把一本书读扎实,是一种难能可贵的品质。
编译后的 kernel 函数签名遵循以下 ABI 约定:
typedef void (*kernel_ptr_t)(
// 用户定义的参数(非 constexpr)
float* x_ptr, // tl.pointer
float* y_ptr,
float* out_ptr,
int32_t n_elements, // tl.int32
// 固定的 6 个调度参数(由编译器自动追加)
uint32_t pid0, // tl.program_id(0)
uint32_t pid1, // tl.program_id(1)
uint32_t pid2, // tl.program_id(2)
uint32_t gridX, // tl.num_programs(0)
uint32_t gridY, // tl.num_programs(1)
uint32_t gridZ // tl.num_programs(2)
);
注意
pid0/pid1/pid2
使用
int32t(有符号),而 gridX/gridY/gridZ 使用 uint32t
(无符号),这与
FuncOpToLLVM.cpp
中的类型定义一致。
我们再补充一下 dlopen/dlsym 动态加载的机制。
CPUUtils.load_binary()
使用 Python 的
ctypes.cdll.LoadLibrary()
动态加载编译好的 .so:
def load_binary(self, name, kernel, shared_mem, device):
# 1. 检查 ELF 架构匹配
machine = _elf_machine(kernel)
if machine != host_machine:
raise RuntimeError("Cannot load ... ELF binary on ... host")
# 2. 写入临时文件并加载
with tempfile.NamedTemporaryFile(suffix=".so") as f:
f.write(kernel)
f.flush()
lib = ctypes.cdll.LoadLibrary(f.name)
fn_ptr = getattr(lib, name)
return (lib, fn_ptr, 0, 0, 0)
交叉编译的产物在宿主机上无法直接加载(ELF machine 不匹配),必须通过 QEMU 或真实目标硬件执行。
05
QEMU 验证体系
Triton CPU 使用的是 qemu-user 用户态仿真,而非 qemu-system 全系统仿真,主要是出于性能的考虑。经过与 Agent 的协同验证,我们确定了以下事实:
qemu-user 可以验证的功能有:多线程功能正确、多线程调度与同步。
qemu-user 不可以验证的功能:RVWMO 内存序,由于集成了主机内存模型,relaxed ordering 的 bug 可能被掩盖;多核模拟通过 pthread,不代表真实 RISC-V。
由于交叉编译的 Triton kernel 是一个 .so 共享库(没有
main
函数),Agent 根据 Plan 设计了 QemuRunner 为每种测试场景生成一个独立的 C 驱动程序,负责 kernel.so 的加载,数据的输入和按照 grid 遍历方式调用 kernel,以及结果的输出。
驱动程序的编译和执行流程:
def _compile_driver(self, driver_src, tmpdir, extra_flags=None):
# 1. 写入 C 源码
Path(src_path).write_text(driver_src)
# 2. 交叉编译
cmd = [self.cc, src_path, "-o", bin_path,
f"-march={march}", f"-mabi={self.target_abi}",
"-ldl", "-lm", "-O2"]
# 3. 可选 OpenMP 支持
if extra_flags:
cmd.extend(extra_flags) # e.g., ["-fopenmp"]
def _run_qemu(self, binary, args, tmpdir):
# 1. 构建 qemu 命令
qemu_cmd = [self.qemu]
if self._needs_cpu_max():
qemu_cmd += ["-cpu", "max"]
# 2. 设置库搜索路径
qemu_cmd += ["-L", self.sysroot]
qemu_cmd += ["-E", f"LD_LIBRARY_PATH={':'.join(lib_paths)}"]
# 3. 执行(60 秒超时)
ret = subprocess.run(qemu_cmd, capture_output=True, text=True, timeout=60)
为了确保 Agent 生成代码的正确性和稳定性,笔者制定了面向模型片段测试的计划,按照 Triton 特性的复杂度递增排列测试,每个阶段引入一个新的计算范式。主要覆盖:
单算子正确性,比如基本向量操作、RVV 向量操作、BF16/FP16 类型验证;
单遍规约 Softmax,测试覆盖:baseline/RVV/OMP/all-inf/non-pow2 等,精度范围在 atol=1e-4,rtol=1e-4;
多遍规约 LayerNorm,测试覆盖
:
baseline/multipass/RVV/OMP + baseline-vs-RVV 三输出对比,精度范围在 atol=1e-3, rtol=1e-3;
多 Kernel 组合 MLP,baseline/RVV + invalid .so + wrong symbol + dimension mismatch,精度范围在 atol=0.1, rtol=1e-2。
为了提高交付质量,我让 Agent 根据 feature 尽可能添加更多的正向反向测试,力求覆盖所有功能点。
06
humanize 协作复盘
最后复盘一下整个协作流程。这次工作的核心工具链是 Claude Code + humanize skill,后者提供了 gen-idea → gen-plan → rlcr-loop 的结构化工作流,并在每轮 loop 结束时调用 Codex(GPT-5.5, xhigh effort)进行独立审查。
整个项目拆成了四份 Plan:
Plan
任务数
AC 数
里程碑
RISC-V 后端功能实现
23
9
5(从 native smoke test 到 RVV 优化)
RVV/BF16/FP16 测试
22
9
5(从回归基线到 SLEEF 集成)
模型片段测试
12
6
4(Softmax → LayerNorm → MLP)
文档重构
9
5
4(8 文件合并为 3 文件,去重 45%)
四份 Plan 合计 66 个 Task、29 个验收标准、18 个里程碑,规划阶段用了 3 天(4 月 27-29 日),执行阶段每份 Plan 平均 3 小时、9 轮 loop。
humanize 在这个过程中解决的核心问题是
估算校正
。
初始 draft 里,笔者对交叉编译的工作量估计是"300-500 行代码",
Codex 在第二轮审查中拿着代码指出
tritoncpu.cc硬编码了 opts.x86 = true、veclibrequirements里完全没有 RISC-V 的分支,且translatetohostasm
无法处理交叉编译场景。
另一个价值是
盲区发现
。16 轮 Codex 审查(其中 13 轮使用 GPT-5.5)累计识别出 7 个主要技术盲区:
x86 硬编码、交叉编译复杂度、RVV Pass 的 C++/CMake/pybind 工作量、cache key 设计、16 核验证标准、OpenMP 验证方法、多 kernel 组合架构。这些问题如果在执行阶段才暴露,返工成本会很高。
回到开头的问题:fork 还是创新?
说实在的,这次工作处在两者的交界处。
ConvertDotToRVV 是 ConvertDotToFMA 的 RISC-V 变体,SLEEF 集成是在 x86 路径上做架构扩展。从实现手法上看,Agent 做的仍然是"有参考实现的迁移",属于高质量的 fork。
但结果是新的。Triton 之前没有 RISC-V CPU 后端,现在有了。笔者之前不了解 MLIR Pass 体系,现在能拆解六个编译阶段的 Pass 流水线并写成文章。
这大概是 Agent 时代"创新"的一个特征:不需要从零开始。
你带着一个领域的深度进来,Agent 帮你桥接另一个领域的框架。笔者的 RISC-V 知识决定了方向和判断标准,Agent 填补了 Triton 编译器框架的知识盲区。
16 轮 Codex 审查帮笔者发现了 7 个自己找不到的技术问题,gen-idea 的并行 explore 帮笔者在几分钟内建立了对陌生代码库的结构认知。
fork 的上限是原作。创新的上限取决于你带进来的领域知识。
Agent 把"跨领域组合"的门槛从"两个领域都精通"降到了"一个领域精通 + 另一个领域能提出正确的问题"。
但这个工作流也有明确的边界。如果任务是为 RISC-V 矩阵扩展从零设计一套 tile 调度算法,gen-plan 生成的 Plan 质量会大打折扣,因为 Agent 缺乏参考锚点。Agent 擅长在已知模式上做高质量变体,不擅长在未知空间做原创探索。
认清这个边界,才能把 Agent 用在正确的地方。
07
交付与信任
Agent 写完代码,只是开始。真正的问题是:怎么把这份产出交付出去,让别人信任它?
第一反应可能是"Agent 实现得多好、多完美"。但这不是第一要义。第一要义是:你作为开发者,有没有完全理解 Agent 的实现。
代码交付以后,面对的是人。
有人会问你为什么选 RVVM1 而不是 RVVM2,有人会在 code review 里质疑 ConvertDotToRVV 的分块策略,有人会在生产环境遇到一个 SLEEF 符号找不到的链接错误。
这些场景需要的不是再跑一次 Agent,而是你能及时响应、定位问题、给出解释。
这也是前面花大量篇幅拆解六个编译阶段的原因。
不是为了展示 Agent 的产出,而是为了确认笔者自己理解了每一层 Pass 在做什么。如果只是把 Agent 的输出贴上去,第一个 bug report 就会暴露你不懂自己交付的东西。
更现实的问题是:Agent 不会一直在。
会话断了、context 超了、工具链升级了、Agent 的输出质量回退了——这些都会发生。到那个时候,你能不能在 Agent 的代码上继续迭代?能不能手动改一个 Pass、加一个测试、修一个交叉编译的链接参数?
这就是 Git 工作流的价值所在。开发分支保留了 19 轮 loop 的完整记录,每一轮做了什么、为什么改、Codex 审查发现了什么问题,全部可追溯。
main 分支上的 5 个 feature commit 各自独立可编译。任何一个人拿到这个仓库,都可以从任意一个 commit 开始,理解上下文,继续工作。
所以交付的本质不是 Agent 的代码质量,而是你对这份代码的掌控力。别人信任的不是 Agent,是开发者本人。
啰嗦结束,祝大家五一假期快乐~
撰文:泽文(Zevorn)
首图:主核(Kernyr)
审校:泽文(Zevorn)
图片、资料来源:
[1] OpenAI Triton 官方仓库:https://github.com/triton-lang/triton
[2] triton-cpu(CPU 后端)仓库:https://github.com/triton-lang/triton-cpu
[3] Triton 官方文档与教程:https://triton-lang.org
[4] SLEEF 向量化数学库:https://github.com/shibatch/sleef
[5] QEMU 官方文档:https://www.qemu.org/docs/master/
[6] RISC-V 规范(向量扩展 V 1.0):https://github.com/riscv/riscv-v-spec
[7] MLIR 官方文档:https://mlir.llvm.org
[8] LLVM 官方文档:https://llvm.org/docs/
[9] 文中示意图,由 GPT image-2 生成。
推荐阅读
- 关注公众号,
免费使用
社区提供的 ima 知识库
现已推出:
AI Infra/QEMU/Compiler/Linux