Triton其实是和CUDA类似的,严格意义上算是一种基于python的DSL,但是其中可以使用python语法直接实现复杂的功能,同时由于在编译层面做了优化,因此可以减少很多不必要的麻烦,对应的也会因为增加一层黑盒而提升对其中的理解难度。这里这次仅仅通过语法和实现的角度来实现几个triton算子以及对应的示例。关于未来对于原理的解析或许可以慢慢理解并应用于优化。
基础概念
关于Triton的一些细粒度说明和原理后面就想到哪里写到哪里吧。首先从实现的角度来看的话,确实Triton还是非常相似于CUDA的,也因此能使用简单的语法实现类似CUDA的性能。
核心点1:Triton的调度级别是以Block为单位的,也因此可以屏蔽很多关于Warp、Thread的逻辑。
对于传统的CUDA实现而言,尽管相对来说有着较好的性能,但是带来的重要问题就是编写难度过大,编写时需要考虑较多的细节来实现更好的性能,对于CUDA而言,主要需要考虑如下方面:
- 从 DRAM 的内存传输必须合并成大型事务,以利用现代内存接口的大总线宽度(内存合并访问)。
- 数据必须在重复使用前手动存储到 SRAM 中,并进行管理来最小化bank conflict。
- 计算必须仔细地进行划分和调度,不仅是在流式多处理器(SMs)之间,还包括在其内部,以促进指令/线程级并行性,并利用专用的 ALU(例如,Tensor Cores)。
因此,相较于直接编写CUDA,Triton能帮助更好的实现从而关注算法而不是语法本身。
如果使用Triton,内存事务合并、SRAM管理以及SM内的线程调度都是自动进行的,我们只需要把精力花在SM之间管理即可,这也就是说,Triton的编程粒度是Block(每个Block只会被调度到一个SM上),而不是Thread。我们只需要考虑每个Block需要做什么,至于Thread/Warp的分布和调度,Triton自动给我们处理了。那么,Block这个概念,在Triton中通过什么进行表达呢?答案是:program。

接下来来一个最简单的例子,向量求和来说明语法的核心部分:
import triton
import triton.language as tl
@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input vector.
y_ptr, # *Pointer* to second input vector.
output_ptr, # *Pointer* to output vector.
n_elements, # Size of the vector.
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
# NOTE: `constexpr` so it can be used as a shape value.
):
# 有多个'程序'(也就是block)处理不同的数据。我们在这里标识我们是哪个程序:
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
# 该程序将处理与初始数据偏移的输入。
# 例如,如果您有长度为 256 的向量和块大小为 64,程序
# 将分别访问元素[0:64, 64:128, 128:192, 192:256]。
# 请注意,偏移量是指针的列表:
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# 创建一个mask以防止内存操作超出范围。
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
以上就是一个精炼且美妙的算子实现,不仅简单而且非常优雅地实现了向量求和(当然对于CUDA而言向量求和可能更容易实现)
之后是使用python进行封装操作:
def add(x: torch.Tensor, y: torch.Tensor):
# 我们需要预先分配输出。
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.numel()
# SPMD启动网格表示并行运行的内核实例数。
# 它类似于CUDA启动网格。对于add_kernel我们使用一个1D网格,其大小是块的数量:
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
# 注意:
# - 每个torch.tensor对象都隐式地转换为指向其第一个元素的指针。
# - `triton.jit`'ed函数可以通过一个启动网格索引来获得一个可调用的GPU内核。
# - 不要忘记将元参数作为关键字参数传递。
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
# 我们返回一个指向z的句柄,但是,由于`torch.cuda.synchronize()`尚未被调用,内核此时仍在异步运行。
return output
上述两部分代码就揭示了Triton的编程核心,首先使用python封装Kernel并传入Grid参数,之后只需要在Kernel内部调用类似CUDA的实现即可。
对于Triton而言,其编译流程主要是Source->AST->Dialect->IR->PTX流程,当然细讲的话就过于复杂了,这里还是强调语法部分。想要学习的话可以参考这里:https://zhuanlan.zhihu.com/p/695171704
相关实现
Fused-Softmax
首先,关于softmax的实现就不再多说了,使用online来使得softmax访存次数只需要一次,但是问题在于Triton的实现好像不一般:
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
num_stages: tl.constexpr):
# starting row of the program
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
还是一行一行来吧,由于softmax操作一般是二维矩阵的操作,因此实际上就是一个二维的计算过程,且考虑到基于blcok实现的优化,自然相对来说会更加复杂一点。