
CUDA Programming: General Matrix Multiplication (GeMM) and CUDA Optimization

Along the way I stumbled upon a tutorial on implementing git in Python; it was genuinely interesting, so I spent some time on it.

A useful calculator tool:

https://xmartlabs.github.io/cuda-calculator/

What is GeMM

To be honest, I initially thought GeMM and plain matrix multiplication (MM) were the same thing. After a closer look I realized there is a difference: the former is essentially an abstraction of the latter in the computer-systems domain.

From a purely mathematical point of view, matrix multiplication has a rather complex optimization space that I cannot follow; from a computer's point of view, however, much more efficient computation is achievable. In what follows we work our way from the CPU architecture toward the GPU.

Optimization on the CPU

Without further ado.

image-20240807204907321

The simplest version is of course the one above, but its time complexity is $O(n^3)$.

The corresponding pseudocode is straightforward:

for (int m = 0; m < M; m++) {
  for (int n = 0; n < N; n++) {
    C[m][n] = 0;
    for (int k = 0; k < K; k++) {
      C[m][n] += A[m][k] * B[k][n];
    }
  }
}

Optimizations of such an algorithm generally fall into two categories: improving the algorithm itself mathematically, and exploiting the hardware architecture to implement it more efficiently.

The former is not covered here; after all, that part really is the mathematicians' job. What we want to do is use the architecture to optimize the latter as much as possible.

Exploiting spatial and temporal locality

Honestly, once you pull up this figure it is clear what is going on:

image-20240807205344740

Because memory is laid out linearly and cached, each access brings four consecutive elements into the cache, so the loop can be unrolled to speed things up (in practice the unroll factor depends on the memory hierarchy):

for (int m = 0; m < M; m++) {
  for (int n = 0; n < N; n += 4) {
    C[m][n + 0] = 0;
    C[m][n + 1] = 0;
    C[m][n + 2] = 0;
    C[m][n + 3] = 0;
    for (int k = 0; k < K; k++) {
      C[m][n + 0] += A[m][k] * B[k][n + 0];
      C[m][n + 1] += A[m][k] * B[k][n + 1];
      C[m][n + 2] += A[m][k] * B[k][n + 2];
      C[m][n + 3] += A[m][k] * B[k][n + 3];
    }
  }
}

Up to this step everything is still easy to follow. The next transformation is applied to the output, because the output enjoys the same caching property:

image-20240807210020859

So it becomes:

for (int m = 0; m < M; m += 4) {
  for (int n = 0; n < N; n += 4) {
    C[m + 0][n + 0..3] = 0;
    C[m + 1][n + 0..3] = 0;
    C[m + 2][n + 0..3] = 0;
    C[m + 3][n + 0..3] = 0;
    for (int k = 0; k < K; k++) {
      C[m + 0][n + 0..3] += A[m + 0][k] * B[k][n + 0..3];
      C[m + 1][n + 0..3] += A[m + 1][k] * B[k][n + 0..3];
      C[m + 2][n + 0..3] += A[m + 2][k] * B[k][n + 0..3];
      C[m + 3][n + 0..3] += A[m + 3][k] * B[k][n + 0..3];
    }
  }
}

Since the inner loop performs a reduction into the C matrix, keeping those partial sums in registers is in theory the fastest option, so the computation can be decomposed into $4 \times 4$ sub-blocks.

image-20240807210511354

Since pseudocode cannot express the register-level optimization, here is a picture of the overall scheme:

image-20240807210404402
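To make the register part concrete, here is a rough C sketch of a 4×4 register-blocked micro-kernel (my own illustration, not code from the figure); it assumes row-major A, B, C with M and N divisible by 4:

for (int m = 0; m < M; m += 4) {
  for (int n = 0; n < N; n += 4) {
    float acc[4][4] = {0};                 // 4x4 accumulators, intended to stay in registers
    for (int k = 0; k < K; k++) {
      for (int i = 0; i < 4; i++) {
        float a_ik = A[m + i][k];          // one load of A reused across 4 columns
        for (int j = 0; j < 4; j++) {
          acc[i][j] += a_ik * B[k][n + j];
        }
      }
    }
    for (int i = 0; i < 4; i++)            // write the finished block back to C once
      for (int j = 0; j < 4; j++)
        C[m + i][n + j] = acc[i][j];
  }
}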

In addition, depending on the precision of the matrix elements, further blocking can be done along memory-block boundaries (at this point we have to go down to how the data structures are laid out in memory):

image-20240807211314668

The code then becomes:

for (int mo = 0; mo < M; mo += 8) {
  for (int no = 0; no < N; no += 8) {
    for (int mi = 0; mi < 2; mi++) {
      for (int ni = 0; ni < 2; ni++) {
        int m = mo + mi * 4;
        int n = no + ni * 4;
        C[m + 0..3][n + 0..3] = 0;
        for (int k = 0; k < K; k += 4) {
          C[m + 0..3][n + 0..3] += A[m + 0..3][k + 0] * B[k + 0][n + 0..3];
          C[m + 0..3][n + 0..3] += A[m + 0..3][k + 1] * B[k + 1][n + 0..3];
          C[m + 0..3][n + 0..3] += A[m + 0..3][k + 2] * B[k + 2][n + 0..3];
          C[m + 0..3][n + 0..3] += A[m + 0..3][k + 3] * B[k + 3][n + 0..3];
        }
      }
    }
  }
}

Honestly, at this point the CPU-side optimization is more or less done, but the ideas above will continue to shine on the GPU. Either way, let's first write a CPU version in C as a reference:

#define OFFSET(row, col, ld) ((row) * (ld) + (col))

void cpuSgemm(
    float *a, float *b, float *c, const int M, const int N, const int K) {

    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            float psum = 0.0;
            for (int k = 0; k < K; k++) {
                psum += a[OFFSET(m, k, K)] * b[OFFSET(k, n, N)];
            }
            c[OFFSET(m, n, N)] = psum;
        }
    }
}
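A minimal usage sketch for this reference implementation (my own test harness with made-up sizes; every element of c should come out equal to K):

#include <stdlib.h>

int main() {
    const int M = 256, N = 256, K = 256;
    float *a = (float *)malloc(M * K * sizeof(float));
    float *b = (float *)malloc(K * N * sizeof(float));
    float *c = (float *)malloc(M * N * sizeof(float));
    for (int i = 0; i < M * K; i++) a[i] = 1.0f;   // A filled with ones
    for (int i = 0; i < K * N; i++) b[i] = 1.0f;   // B filled with ones
    cpuSgemm(a, b, c, M, N, K);                    // every c[i] should equal K = 256
    free(a); free(b); free(c);
    return 0;
}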

Here comes CUDA

The CPU does offer some scalability, but nowadays CUDA-based parallel acceleration is simply more powerful.

Let's start with the simplest possible GPU version of GeMM:

__global__ void naiveSgemm(
    float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
    const int M, const int N, const int K) {

    int n = blockIdx.x * blockDim.x + threadIdx.x;
    int m = blockIdx.y * blockDim.y + threadIdx.y;
    if (m < M && n < N) {
        float psum = 0.0;
        #pragma unroll
        for (int k = 0; k < K; k++) {
            psum += a[OFFSET(m, k, K)] * b[OFFSET(k, n, N)];
        }
        c[OFFSET(m, n, N)] = psum;
    }
}

const int BM = 32, BN = 32;
const int M = 512, N = 512, K = 512;
dim3 blockDim(BN, BM);
dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);
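For completeness, this configuration would be used roughly as follows (a hedged host-side sketch; d_a, d_b, d_c are device pointers I introduce here, not names from the original post):

float *d_a, *d_b, *d_c;
cudaMalloc(&d_a, M * K * sizeof(float));
cudaMalloc(&d_b, K * N * sizeof(float));
cudaMalloc(&d_c, M * N * sizeof(float));
// ... cudaMemcpy a and b from the host to d_a and d_b here ...
naiveSgemm<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, N, K);
cudaDeviceSynchronize();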

This one is simple enough: each thread handles one element of the output matrix. But clearly the code above implements no optimization of the memory or compute structure whatsoever, so there is a lot of optimizing left to do.

First, the memory and compute flow of the code above:

  • Allocate storage for the three matrices in global memory
  • Every element of C is computed independently, so each thread computes one value
  • Configure the thread launch (honestly, this kind of configuration is much easier to understand than what I used while learning before, though intuitively its performance should not be high, since the warps are not adjacent):

$$gridDim.x \times blockDim.x = N$$

$$gridDim.y \times blockDim.y = M$$

Each thread's workflow is: read a row vector of length K from matrix A, read a column vector of length K from matrix B, accumulate their dot product in a loop, and finally write the result back to C. The overall read/write traffic is quite expensive in bandwidth:

$$K\times M\times N\times 4\,\text{Bytes} + M\times N \times 4\,\text{Bytes}$$
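Plugging in the launch configuration above (M = N = K = 512), my own back-of-the-envelope numbers give:

$$512 \times 512 \times 512 \times 4\,\text{B} + 512 \times 512 \times 4\,\text{B} = 512\,\text{MiB} + 1\,\text{MiB}$$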

Since 32 threads belong to one warp (architecture dependent), a warp can read data for 32 columns of matrix B at once. Even so, it is still far from enough.

Optimization step 1: shared memory

Each multiply-accumulate needs two loads from global memory, and this access pattern makes performance quite poor. So some of the data can be placed in shared memory.

First split matrix C into tiles of size $BM\times BN$. Each tile is computed by one block, and within it each thread computes $TM\times TN$ elements; the data needed for the computation can then be read from shared memory (the warps are grouped into a block, and this shared memory is shared within the block).

What follows is a complicated but very important analysis.

After tiling, each tile involves:

Compute: $BM\times BN\times K \times 2$ FLOPs

Memory traffic: $(BM + BN)\times K \times 4\,\text{Bytes}$

Dividing the two gives the compute-to-memory ratio $\frac{BM\cdot BN}{2(BM+BN)}=\frac{1}{2(\frac{1}{BN}+\frac{1}{BM})}$
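For a concrete feel (my own numbers): $BM=BN=64$ gives a ratio of $\frac{64\cdot 64}{2\cdot 128}=16$, while $BM=BN=128$ gives $\frac{128\cdot 128}{2\cdot 256}=32$.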

Clearly, the larger BM and BN, the higher the compute-to-memory ratio and the better the performance. But as covered in the basics, various constraints prevent these numbers from growing without bound.

First there is the shared memory capacity: on a V100 one SM has only 96 KB of shared memory, while one block's tiles occupy $BK \times (BM+BN)\times 4\,\text{Bytes}$.

Moreover, TM and TN are also constrained. Different architectures cap the total number of threads per block; on V100 a block may not exceed 1024 threads, and if the block is too small, parallelism among blocks on an SM suffers. In addition, registers are limited: each thread alone needs $TM\times TN$ registers to hold its results, and since the per-thread total cannot exceed 256, pushing too far also hurts parallelism.

The analysis above, though complicated, is quite effective and useful.

Based on these considerations, choose $BM=BN=128$, $BK=8$, $TM=TN=8$. The code then looks as follows (best consumed together with the figure):
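A quick sanity check of these choices against the V100 limits mentioned above (my own arithmetic, not from the original analysis):

$$\frac{BM\cdot BN}{2(BM+BN)} = \frac{128\cdot 128}{2\cdot 256} = 32, \qquad BK\,(BM+BN)\times 4\,\text{B} = 8\times 256\times 4\,\text{B} = 8\,\text{KB}$$

$$\frac{BM}{TM}\times\frac{BN}{TN} = 16\times 16 = 256\ \text{threads/block}, \qquad TM\times TN = 64\ \text{result registers/thread}$$

all of which fit comfortably within the 96 KB of shared memory, 1024 threads per block, and 256 registers per thread.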

image-20240809201303889

#define FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])
/*
  FLOAT4(pointer) reinterprets `pointer` as a float4* and accesses its first element.
  Concretely, it casts the address to float4* and returns a reference to that float4,
  so four consecutive floats can be moved with a single 128-bit access.
*/
__global__ void sgemm_V1(
    float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
    const int M, const int N, const int K) {

    const int BM = 128;
    const int BN = 128;
    const int BK = 8;
    const int TM = 8;
    const int TN = 8;

    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int tid = ty * blockDim.x + tx;

    __shared__ float s_a[BM][BK];
    __shared__ float s_b[BK][BN];

    float r_c[TM][TN] = {0.0};

    int load_a_smem_m = tid >> 1;  // tid/2, row of s_a
    int load_a_smem_k = (tid & 1) << 2;  // (tid % 2 == 0) ? 0 : 4, col of s_a
    int load_b_smem_k = tid >> 5;   // tid/32, row of s_b
    int load_b_smem_n = (tid & 31) << 2;  // (tid % 32) * 4, col of s_b

    int load_a_gmem_m = by * BM + load_a_smem_m;  // global row of a
    int load_b_gmem_n = bx * BN + load_b_smem_n;  // global col of b

    for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {
        int load_a_gmem_k = bk * BK + load_a_smem_k;   // global col of a
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        FLOAT4(s_a[load_a_smem_m][load_a_smem_k]) = FLOAT4(a[load_a_gmem_addr]);
        int load_b_gmem_k = bk * BK + load_b_smem_k;   // global row of b
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr]);

        __syncthreads();

        #pragma unroll
        for (int k = 0; k < BK; k++) {
            #pragma unroll
            for (int m = 0; m < TM; m++) {
                #pragma unroll
                for (int n = 0; n < TN; n++) {
                    int comp_a_smem_m = ty * TM + m;
                    int comp_b_smem_n = tx * TN + n;
                    r_c[m][n] += s_a[comp_a_smem_m][k] * s_b[k][comp_b_smem_n];
                }
            }
        }

        __syncthreads();
    }

    #pragma unroll
    for (int i = 0; i < TM; i++) {
        int store_c_gmem_m = by * BM + ty * TM + i;
        #pragma unroll
        for (int j = 0; j < TN; j += 4) {
            int store_c_gmem_n = bx * BN + tx * TN + j;
            int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
            FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][j]);
        }
    }
}
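For reference, a launch configuration consistent with these tile sizes might look like this (my own sketch; d_a, d_b, d_c are assumed device pointers): each block has (BN/TN) × (BM/TM) = 16 × 16 = 256 threads and covers one 128 × 128 tile of C.

const int BM = 128, BN = 128, TM = 8, TN = 8;
dim3 blockDim(BN / TN, BM / TM);                       // 16 x 16 = 256 threads per block
dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);    // one block per 128x128 tile of C
sgemm_V1<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, N, K);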

Honestly, compared with the code in earlier tutorials this looks a bit messy, but that does not stop us from analyzing it. Its pipeline is as follows:

  • Load the tiles $A_{[BM, BK]}$ and $B_{[BK, BN]}$ into shared memory

This part is honestly rather involved. For each tile of $A_{[BM, BK]}$, every thread has to move $\frac{BK \cdot TM \cdot TN}{BN}$ elements, which here is 4 floats. That happens to fit exactly into one of CUDA's float4 structures (4 floats × 4 B = 16 B, one aligned vector access). With the tile configuration above, the index mapping is shown on the left of the figure below.

image-20240809193924810

Now the data can be placed into shared memory, which requires an indexing and staging step, namely the s_a and s_b objects manipulated later. The idea is easy to grasp: keep part of the data in shared memory to reduce the number of global-memory accesses. Under this layout, load_a_smem_m = tid / 2 = tid >> 1 is the row index into s_a, and the column index load_a_smem_k = (tid % 2 == 0) ? 0 : 4 = (tid & 1) << 2 is the thread's starting column inside shared memory. The same reasoning gives matrix B's layout: int load_b_smem_k = tid >> 5 and int load_b_smem_n = (tid & 31) << 2.

The above describes only a single block. When multiple blocks index their tiles, the mapping into global memory changes accordingly. Taking matrix A again as the example: the tile $A_{[BM,BK]}$ is taken row-wise, so first determine the row index. From the grid's 2D global indexing, the tile's starting row is by * BM, so the global row is load_a_gmem_m = by * BM + load_a_smem_m. The column is different: because the tile slides along the row direction, the column changes and must be computed inside the loop. The starting column bk * BK plus the in-tile column load_a_smem_k gives load_a_gmem_k = bk * BK + load_a_smem_k, which pins down the tile's position in the original data via OFFSET(load_a_gmem_m, load_a_gmem_k, K).
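As a worked example of this mapping (my own, purely illustrative), take tid = 37 in a 256-thread block with BM = BN = 128 and BK = 8:

int tid = 37;
int load_a_smem_m = tid >> 1;         // 37 / 2  = 18      -> row 18 of s_a
int load_a_smem_k = (tid & 1) << 2;   // (37 % 2) * 4 = 4  -> columns 4..7 of s_a
int load_b_smem_k = tid >> 5;         // 37 / 32 = 1       -> row 1 of s_b
int load_b_smem_n = (tid & 31) << 2;  // (37 % 32) * 4 = 20 -> columns 20..23 of s_b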

  • Compute the per-thread tile $C_{[TM, TN]}$: once s_a and s_b are in place, compute the corresponding r_c and store it back to global memory. This, too, is a non-trivial sequence of index transformations.

Optimization step 2: resolving bank conflicts

The step above greatly improved memory-access efficiency and thus performance; the next step is to further optimize how shared memory is used.

This optimization actually left a strong impression on me. Shared memory is divided into 32 banks, each 4 B wide; when multiple accesses hit the same bank at once, a bank conflict occurs and they are serialized. The fix seen earlier was to stagger (pad) the layout.

First, the bank conflicts caused by the matrix multiplication above:

  • Fetching from matrix A requires a column vector, but A is stored row-wise in shared memory, so the accesses conflict
  • Moreover, with TM = TN = 8, each thread needs 8 consecutive addresses from shared memory. Since one instruction fetches 4 floats, two load instructions are needed. The two loads of one thread are at consecutive addresses, while for the same load instruction the addresses of different threads in a warp are strided apart, so several threads end up accessing the same bank at once, which again causes bank conflicts

So two optimizations are needed:

  • When staging A, transpose it so that it is stored column-wise (s_a becomes [BK][BM]).
  • Split the TM×TN tile each thread computes into two halves; since one instruction performs one load of A's slice, the two loads can proceed independently.

(Although it is not mentioned here, from earlier programming we know that bank conflicts can also be resolved by padding the layout.)
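For reference, that padding trick usually looks like the sketch below (my own generic example, not part of the kernel that follows): one extra float per row shifts successive rows onto different banks.

// Without the +1, tile[t][k] for fixed k and t = 0..31 would all map to the same
// bank (stride of 32 floats), serializing a column read into 32 transactions.
// With 33 floats per row the stride is coprime with the 32 banks, so the 32
// accesses land in 32 different banks.
__shared__ float tile[32][33];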

image-20240809203004897

__global__ void sgemm_V2(
    float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
    const int M, const int N, const int K) {

    const int BM = 128;
    const int BN = 128;
    const int BK = 8;
    const int TM = 8;
    const int TN = 8;

    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int tid = ty * blockDim.x + tx;

    __shared__ float s_a[BK][BM];
    __shared__ float s_b[BK][BN];

    float r_load_a[4];
    float r_load_b[4];
    float r_comp_a[TM];
    float r_comp_b[TN];
    float r_c[TM][TN] = {0.0};

    int load_a_smem_m = tid >> 1;
    int load_a_smem_k = (tid & 1) << 2;
    int load_b_smem_k = tid >> 5;
    int load_b_smem_n = (tid & 31) << 2;

    int load_a_gmem_m = by * BM + load_a_smem_m;
    int load_b_gmem_n = bx * BN + load_b_smem_n;

    for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {

        int load_a_gmem_k = bk * BK + load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        int load_b_gmem_k = bk * BK + load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
        FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);

        s_a[load_a_smem_k    ][load_a_smem_m] = r_load_a[0];
        s_a[load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
        s_a[load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
        s_a[load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
        FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);

        __syncthreads();

        #pragma unroll
        for (int tk = 0; tk < BK; tk++) {
            FLOAT4(r_comp_a[0]) = FLOAT4(s_a[tk][ty * TM / 2         ]);
            FLOAT4(r_comp_a[4]) = FLOAT4(s_a[tk][ty * TM / 2 + BM / 2]);
            FLOAT4(r_comp_b[0]) = FLOAT4(s_b[tk][tx * TN / 2         ]);
            FLOAT4(r_comp_b[4]) = FLOAT4(s_b[tk][tx * TN / 2 + BN / 2]);

            #pragma unroll
            for (int tm = 0; tm < TM; tm++) {
                #pragma unroll
                for (int tn = 0; tn < TN; tn++) {
                    r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
                }
            }
        }

        __syncthreads();
    }

    #pragma unroll
    for (int i = 0; i < TM / 2; i++) {
        int store_c_gmem_m = by * BM + ty * TM / 2 + i;
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
        FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
    }
    #pragma unroll
    for (int i = 0; i < TM / 2; i++) {
        int store_c_gmem_m = by * BM + BM / 2 + ty * TM / 2 + i;
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
        FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
    }
}

Optimization step 3: pipelining with double buffering

By adding a second buffer, the whole process becomes a pipeline, which reduces waiting and improves efficiency:

image-20240809203614053

From the code alone it is not obvious at first glance what changed:

__global__ void sgemm_V3(
    float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
    const int M, const int N, const int K) {

    const int BM = 128;
    const int BN = 128;
    const int BK = 8;
    const int TM = 8;
    const int TN = 8;

    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int tid = ty * blockDim.x + tx;

    __shared__ float s_a[2][BK][BM];
    __shared__ float s_b[2][BK][BN];

    float r_load_a[4];
    float r_load_b[4];
    float r_comp_a[TM];
    float r_comp_b[TN];
    float r_c[TM][TN] = {0.0};

    int load_a_smem_m = tid >> 1;
    int load_a_smem_k = (tid & 1) << 2;
    int load_b_smem_k = tid >> 5;
    int load_b_smem_n = (tid & 31) << 2;

    int load_a_gmem_m = by * BM + load_a_smem_m;
    int load_b_gmem_n = bx * BN + load_b_smem_n;

    {
        int load_a_gmem_k = load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        int load_b_gmem_k = load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
        FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);

        s_a[0][load_a_smem_k    ][load_a_smem_m] = r_load_a[0];
        s_a[0][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
        s_a[0][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
        s_a[0][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
        FLOAT4(s_b[0][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
    }

    for (int bk = 1; bk < (K + BK - 1) / BK; bk++) {

        int smem_sel = (bk - 1) & 1;
        int smem_sel_next = bk & 1;

        int load_a_gmem_k = bk * BK + load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        int load_b_gmem_k = bk * BK + load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
        FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);

        #pragma unroll
        for (int tk = 0; tk < BK; tk++) {
            FLOAT4(r_comp_a[0]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2         ]);
            FLOAT4(r_comp_a[4]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2 + BM / 2]);
            FLOAT4(r_comp_b[0]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2         ]);
            FLOAT4(r_comp_b[4]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2 + BN / 2]);

            #pragma unroll
            for (int tm = 0; tm < TM; tm++) {
                #pragma unroll
                for (int tn = 0; tn < TN; tn++) {
                    r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
                }
            }
        }

        s_a[smem_sel_next][load_a_smem_k    ][load_a_smem_m] = r_load_a[0];
        s_a[smem_sel_next][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
        s_a[smem_sel_next][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
        s_a[smem_sel_next][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
        FLOAT4(s_b[smem_sel_next][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);

        __syncthreads();
    }

    #pragma unroll
    for (int tk = 0; tk < BK; tk++) {
        FLOAT4(r_comp_a[0]) = FLOAT4(s_a[1][tk][ty * TM / 2         ]);
        FLOAT4(r_comp_a[4]) = FLOAT4(s_a[1][tk][ty * TM / 2 + BM / 2]);
        FLOAT4(r_comp_b[0]) = FLOAT4(s_b[1][tk][tx * TN / 2         ]);
        FLOAT4(r_comp_b[4]) = FLOAT4(s_b[1][tk][tx * TN / 2 + BN / 2]);

        #pragma unroll
        for (int tm = 0; tm < TM; tm++) {
            #pragma unroll
            for (int tn = 0; tn < TN; tn++) {
                r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
            }
        }
    }

    #pragma unroll
    for (int i = 0; i < TM / 2; i++) {
        int store_c_gmem_m = by * BM + ty * TM / 2 + i;
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
        FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
    }
    #pragma unroll
    for (int i = 0; i < TM / 2; i++) {
        int store_c_gmem_m = by * BM + BM / 2 + ty * TM / 2 + i;
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
        FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
    }
}

The core actually lies in this part:

    __shared__ float s_a[2][BK][BM];
    __shared__ float s_b[2][BK][BN];

    float r_load_a[4];
    float r_load_b[4];
    float r_comp_a[TM];
    float r_comp_b[TN];
    float r_c[TM][TN] = {0.0};

    int load_a_smem_m = tid >> 1;
    int load_a_smem_k = (tid & 1) << 2;
    int load_b_smem_k = tid >> 5;
    int load_b_smem_n = (tid & 31) << 2;

    int load_a_gmem_m = by * BM + load_a_smem_m;
    int load_b_gmem_n = bx * BN + load_b_smem_n;

    {
        int load_a_gmem_k = load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        int load_b_gmem_k = load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
        FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);

        s_a[0][load_a_smem_k    ][load_a_smem_m] = r_load_a[0];
        s_a[0][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
        s_a[0][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
        s_a[0][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
        FLOAT4(s_b[0][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
    }

    for (int bk = 1; bk < (K + BK - 1) / BK; bk++) {

        int smem_sel = (bk - 1) & 1;
        int smem_sel_next = bk & 1;

        int load_a_gmem_k = bk * BK + load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        int load_b_gmem_k = bk * BK + load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
        FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);

        #pragma unroll
        for (int tk = 0; tk < BK; tk++) {
            FLOAT4(r_comp_a[0]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2         ]);
            FLOAT4(r_comp_a[4]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2 + BM / 2]);
            FLOAT4(r_comp_b[0]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2         ]);
            FLOAT4(r_comp_b[4]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2 + BN / 2]);

            #pragma unroll
            for (int tm = 0; tm < TM; tm++) {
                #pragma unroll
                for (int tn = 0; tn < TN; tn++) {
                    r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
                }
            }
        }

        s_a[smem_sel_next][load_a_smem_k    ][load_a_smem_m] = r_load_a[0];
        s_a[smem_sel_next][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
        s_a[smem_sel_next][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
        s_a[smem_sel_next][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
        FLOAT4(s_b[smem_sel_next][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);

        __syncthreads();
    }

A separate scope is deliberately marked out with braces here. It looks like nothing more than peeling one iteration out of the loop, but the change is actually substantial. If everything lived inside one scope, the instructions could not be recognized as independent of one another, which leads to waiting. In fact, once a tile has been read out of shared memory it can be temporarily discarded; by carving out this scope, the compiler knows that this part is independent of what follows, so the write-back can proceed directly, which improves performance.

With the optimizations above, performance can actually approach the level of cuBLAS. (Worth revisiting again and again.)

Licensed under CC BY-NC-SA 4.0