Skip to content

Commit

Permalink
CUDA: deduplicate FlashAttention code (#7352)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed May 18, 2024
1 parent cb42c29 commit 133d99c
Show file tree
Hide file tree
Showing 8 changed files with 315 additions and 653 deletions.
11 changes: 11 additions & 0 deletions ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,17 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -

typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);

static __device__ __forceinline__ float get_alibi_slope(
const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
) {
if (max_bias <= 0.0f) {
return 1.0f;
}
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

return powf(base, exph);
}

//////////////////////

Expand Down
115 changes: 115 additions & 0 deletions ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
@@ -1,7 +1,44 @@
#include "common.cuh"

#include <cstdint>

#define FATTN_KQ_STRIDE 256
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

typedef void (* fattn_kernel_t)(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int ne0,
const int ne1,
const int ne2,
const int ne3);

template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
Expand Down Expand Up @@ -45,3 +82,81 @@ static __global__ void flash_attn_combine_results(

dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
}

template <int D, int parallel_blocks>
void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];

const ggml_tensor * mask = dst->src[3];

ggml_tensor * KQV = dst;

GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(K->type == GGML_TYPE_F16);
GGML_ASSERT(V->type == GGML_TYPE_F16);
GGML_ASSERT(KQV->type == GGML_TYPE_F32);

GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();

ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}

const dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;

float scale = 1.0f;
float max_bias = 0.0f;

memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));

const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
(const char *) Q->data,
(const char *) K->data,
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
CUDA_CHECK(cudaGetLastError());

if ((parallel_blocks) == 1) {
return;
}

const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;

flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());
}
135 changes: 26 additions & 109 deletions ggml-cuda/fattn-tile-f16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,8 @@ static __global__ void flash_attn_tile_ext_f16(

const int stride_KV2 = nb11 / sizeof(half2);

half slopeh = __float2half(1.0f);

// ALiBi
if (max_bias > 0.0f) {
const uint32_t h = blockIdx.y;

const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

slopeh = __float2half(powf(base, exph));
}
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);

static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");

Expand Down Expand Up @@ -272,124 +263,50 @@ static __global__ void flash_attn_tile_ext_f16(
#endif // FP16_AVAILABLE
}

template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f16(
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
ggml_cuda_pool & pool, cudaStream_t main_stream
) {
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}

constexpr int nwarps = 8;
const dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;

float scale = 1.0f;
float max_bias = 0.0f;

memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));

const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data,
(const char *) K->data,
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
CUDA_CHECK(cudaGetLastError());

if (parallel_blocks == 1) {
return;
template <int cols_per_block, int parallel_blocks>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
default: {
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
} break;
}

const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;

flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());
}

void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];

const ggml_tensor * mask = dst->src[3];

ggml_tensor * KQV = dst;
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];

const int32_t precision = KQV->op_params[2];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");

if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
switch (Q->ne[0]) {
case 64:
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return;
}

if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
switch (Q->ne[0]) {
case 64:
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return;
}

constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
switch (Q->ne[0]) {
case 64:
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
}

0 comments on commit 133d99c

Please sign in to comment.