返回

Triton

继续啃

Triton其实是和CUDA类似的,严格意义上算是一种基于python的DSL,但是其中可以使用python语法直接实现复杂的功能,同时由于在编译层面做了优化,因此可以减少很多不必要的麻烦,对应的也会因为增加一层黑盒而提升对其中的理解难度。这里这次仅仅通过语法和实现的角度来实现几个triton算子以及对应的示例。关于未来对于原理的解析或许可以慢慢理解并应用于优化。

基础概念

关于Triton的一些细粒度说明和原理后面就想到哪里写到哪里吧。首先从实现的角度来看的话,确实Triton还是非常相似于CUDA的,也因此能使用简单的语法实现类似CUDA的性能。

核心点1:Triton的调度级别是以Block为单位的,也因此可以屏蔽很多关于Warp、Thread的逻辑。

对于传统的CUDA实现而言,尽管相对来说有着较好的性能,但是带来的重要问题就是编写难度过大,编写时需要考虑较多的细节来实现更好的性能,对于CUDA而言,主要需要考虑如下方面:

  1. 从 DRAM 的内存传输必须合并成大型事务,以利用现代内存接口的大总线宽度(内存合并访问)。
  2. 数据必须在重复使用前手动存储到 SRAM 中,并进行管理来最小化bank conflict。
  3. 计算必须仔细地进行划分和调度,不仅是在流式多处理器(SMs)之间,还包括在其内部,以促进指令/线程级并行性,并利用专用的 ALU(例如,Tensor Cores)。

因此,相较于直接编写CUDA,Triton能帮助更好的实现从而关注算法而不是语法本身。

如果使用Triton,内存事务合并、SRAM管理以及SM内的线程调度都是自动进行的,我们只需要把精力花在SM之间管理即可,这也就是说,Triton的编程粒度是Block(每个Block只会被调度到一个SM上),而不是Thread。我们只需要考虑每个Block需要做什么,至于Thread/Warp的分布和调度,Triton自动给我们处理了。那么,Block这个概念,在Triton中通过什么进行表达呢?答案是:program

img

接下来来一个最简单的例子,向量求和来说明语法的核心部分:

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    #  有多个'程序'(也就是block)处理不同的数据。我们在这里标识我们是哪个程序:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # 该程序将处理与初始数据偏移的输入。
    # 例如,如果您有长度为 256 的向量和块大小为 64,程序
    # 将分别访问元素[0:64, 64:128, 128:192, 192:256]。
    # 请注意,偏移量是指针的列表:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # 创建一个mask以防止内存操作超出范围。
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)

以上就是一个精炼且美妙的算子实现,不仅简单而且非常优雅地实现了向量求和(当然对于CUDA而言向量求和可能更容易实现)

之后是使用python进行封装操作:

def add(x: torch.Tensor, y: torch.Tensor):
    # 我们需要预先分配输出。
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # SPMD启动网格表示并行运行的内核实例数。
    # 它类似于CUDA启动网格。对于add_kernel我们使用一个1D网格,其大小是块的数量:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    # 注意:
    #  - 每个torch.tensor对象都隐式地转换为指向其第一个元素的指针。
    #  - `triton.jit`'ed函数可以通过一个启动网格索引来获得一个可调用的GPU内核。
    #  - 不要忘记将元参数作为关键字参数传递。
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # 我们返回一个指向z的句柄,但是,由于`torch.cuda.synchronize()`尚未被调用,内核此时仍在异步运行。
    return output

上述两部分代码就揭示了Triton的编程核心,首先使用python封装Kernel并传入Grid参数,之后只需要在Kernel内部调用类似CUDA的实现即可。

对于Triton而言,其编译流程主要是Source->AST->Dialect->IR->PTX流程,当然细讲的话就过于复杂了,这里还是强调语法部分。想要学习的话可以参考这里:https://zhuanlan.zhihu.com/p/695171704

相关实现

Fused-Softmax

首先,关于softmax的实现就不再多说了,使用online来使得softmax访存次数只需要一次,但是问题在于Triton的实现好像不一般:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)

还是一行一行来吧,由于softmax操作一般是二维矩阵的操作,因此实际上就是一个二维的计算过程,且考虑到基于blcok实现的优化,自然相对来说会更加复杂一点。