Introduce bf16 cuda support
jart committed May 23, 2024
1 parent 0b5997d commit c0aa43e
Showing 1 changed file with 27 additions and 0 deletions.
llama.cpp/ggml-cuda.cu (27 additions, 0 deletions)
@@ -61,9 +61,11 @@
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>

#ifdef GGML_USE_TINYBLAS
#include "tinyblas.cu"
#define __nv_bfloat16 hip_bfloat16
#define CUBLAS_COMPUTE_16F TINYBLAS_COMPUTE_16F
#define CUBLAS_COMPUTE_32F TINYBLAS_COMPUTE_32F
#define CUBLAS_COMPUTE_32F_FAST_16F TINYBLAS_COMPUTE_32F
@@ -95,6 +97,7 @@
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define __nv_bfloat16 hip_bfloat16
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
@@ -191,6 +194,7 @@
#include <cuda_runtime.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "tinyblas.cu"

#define CUBLAS_COMPUTE_16F TINYBLAS_COMPUTE_16F
@@ -223,6 +227,7 @@
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
@@ -2324,6 +2329,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F16:
return convert_unary_cuda<half>;
case GGML_TYPE_BF16:
return convert_unary_cuda<__nv_bfloat16>;
default:
return nullptr;
}
@@ -3348,6 +3355,14 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int
v.y = x[ib + iqs + 1];
}

static __device__ void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const __nv_bfloat16 * x = (const __nv_bfloat16 *) vx;

// automatic __nv_bfloat16 -> float type cast if dfloat == float
v.x = x[ib + iqs + 0];
v.y = x[ib + iqs + 1];
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
@@ -3510,6 +3525,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void convert_mul_mat_vec_bf16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<1, 1, convert_bf16>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

void ggml_cuda_op_dequantize_mul_mat_vec(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
@@ -3575,6 +3599,9 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
case GGML_TYPE_F16:
convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_BF16:
convert_mul_mat_vec_bf16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
default:
GGML_ASSERT(false);
break;
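
For reference, here is a minimal standalone sketch of the bf16 to fp32 widening that this commit wires into the conversion paths above. It is not part of the commit: the names bf16_to_f32_kernel and bf16_to_f32_cuda are illustrative, and it assumes only <cuda_bf16.h> and the CUDA runtime.

// Illustrative sketch only, in the spirit of convert_unary_cuda<__nv_bfloat16>
// and convert_bf16 above; bf16_to_f32_kernel and bf16_to_f32_cuda are
// hypothetical names, not taken from ggml-cuda.cu.
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_runtime.h>

static __global__ void bf16_to_f32_kernel(const __nv_bfloat16 * x, float * y, int64_t n) {
    const int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    // bf16 keeps the fp32 exponent range, so widening to fp32 is exact.
    y[i] = __bfloat162float(x[i]);
}

static void bf16_to_f32_cuda(const __nv_bfloat16 * x, float * y, int64_t n, cudaStream_t stream) {
    const int block_size = 256;
    const int num_blocks = (int) ((n + block_size - 1) / block_size);
    bf16_to_f32_kernel<<<num_blocks, block_size, 0, stream>>>(x, y, n);
}

On the ROCm build the commit instead maps __nv_bfloat16 to hip_bfloat16, which also converts to float implicitly, so the plain assignment style used in convert_bf16 compiles on both toolchains.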