Skip to content

Commit 5c86c9e

Browse files
CUDA: fix crash on large batch size for MoE models (#13384)
1 parent efb8b47 commit 5c86c9e

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

ggml/src/ggml-cuda/getrows.cu

+14 -12
Original file line number | Diff line number | Diff line change
@@ -10,10 +10,11 @@ static __global__ void k_get_rows(
1010
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
1111
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
1212

13-
const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
14-
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
15-
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
16-
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
13+
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
14+
const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
15+
const int i10 = blockIdx.x;
16+
const int i11 = blockIdx.z / ne12;
17+
const int i12 = blockIdx.z % ne12;
1718

1819
if (i00 >= ne00) {
1920
return;
@@ -46,10 +47,11 @@ static __global__ void k_get_rows_float(
4647
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
4748
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
4849

49-
const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
50-
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
51-
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
52-
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
50+
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
51+
const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
52+
const int i10 = blockIdx.x;
53+
const int i11 = blockIdx.z / ne12;
54+
const int i12 = blockIdx.z % ne12;
5355

5456
if (i00 >= ne00) {
5557
return;
@@ -94,8 +96,8 @@ static void get_rows_cuda_q(
9496
const size_t nb1, const size_t nb2, const size_t nb3,
9597
cudaStream_t stream) {
9698
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
97-
const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
98-
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
99+
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
100+
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
99101

100102
// strides in elements
101103
// const size_t s0 = nb0 / sizeof(dst_t);
@@ -127,8 +129,8 @@ static void get_rows_cuda_float(
127129
const size_t nb1, const size_t nb2, const size_t nb3,
128130
cudaStream_t stream) {
129131
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
130-
const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
131-
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
132+
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
133+
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
132134

133135
// strides in elements
134136
// const size_t s0 = nb0 / sizeof(dst_t);

0 commit comments

Comments (0)