@@ -10,10 +10,11 @@ static __global__ void k_get_rows(
10
10
/* const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
11
11
const size_t s10, const size_t s11, const size_t s12/* , const size_t s13*/ ) {
12
12
13
- const int i00 = (blockIdx .x *blockDim .x + threadIdx .x )*2 ;
14
- const int i10 = blockDim .y *blockIdx .y + threadIdx .y ;
15
- const int i11 = (blockIdx .z *blockDim .z + threadIdx .z )/ne12;
16
- const int i12 = (blockIdx .z *blockDim .z + threadIdx .z )%ne12;
13
+ // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
14
+ const int i00 = (blockIdx .y * blockDim .x + threadIdx .x )*2 ;
15
+ const int i10 = blockIdx .x ;
16
+ const int i11 = blockIdx .z / ne12;
17
+ const int i12 = blockIdx .z % ne12;
17
18
18
19
if (i00 >= ne00) {
19
20
return ;
@@ -46,10 +47,11 @@ static __global__ void k_get_rows_float(
46
47
/* const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
47
48
const size_t s10, const size_t s11, const size_t s12/* , const size_t s13*/ ) {
48
49
49
- const int i00 = blockIdx .x *blockDim .x + threadIdx .x ;
50
- const int i10 = blockDim .y *blockIdx .y + threadIdx .y ;
51
- const int i11 = (blockIdx .z *blockDim .z + threadIdx .z )/ne12;
52
- const int i12 = (blockIdx .z *blockDim .z + threadIdx .z )%ne12;
50
+ // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
51
+ const int i00 = blockIdx .y * blockDim .x + threadIdx .x ;
52
+ const int i10 = blockIdx .x ;
53
+ const int i11 = blockIdx .z / ne12;
54
+ const int i12 = blockIdx .z % ne12;
53
55
54
56
if (i00 >= ne00) {
55
57
return ;
@@ -94,8 +96,8 @@ static void get_rows_cuda_q(
94
96
const size_t nb1, const size_t nb2, const size_t nb3,
95
97
cudaStream_t stream) {
96
98
const dim3 block_dims (CUDA_GET_ROWS_BLOCK_SIZE, 1 , 1 );
97
- const int block_num_x = (ne00 + 2 *CUDA_GET_ROWS_BLOCK_SIZE - 1 ) / (2 *CUDA_GET_ROWS_BLOCK_SIZE);
98
- const dim3 block_nums (block_num_x, ne10 , ne11*ne12);
99
+ const int block_num_y = (ne00 + 2 *CUDA_GET_ROWS_BLOCK_SIZE - 1 ) / (2 *CUDA_GET_ROWS_BLOCK_SIZE);
100
+ const dim3 block_nums (ne10, block_num_y , ne11*ne12);
99
101
100
102
// strides in elements
101
103
// const size_t s0 = nb0 / sizeof(dst_t);
@@ -127,8 +129,8 @@ static void get_rows_cuda_float(
127
129
const size_t nb1, const size_t nb2, const size_t nb3,
128
130
cudaStream_t stream) {
129
131
const dim3 block_dims (CUDA_GET_ROWS_BLOCK_SIZE, 1 , 1 );
130
- const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1 ) / CUDA_GET_ROWS_BLOCK_SIZE;
131
- const dim3 block_nums (block_num_x, ne10 , ne11*ne12);
132
+ const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1 ) / CUDA_GET_ROWS_BLOCK_SIZE;
133
+ const dim3 block_nums (ne10, block_num_y , ne11*ne12);
132
134
133
135
// strides in elements
134
136
// const size_t s0 = nb0 / sizeof(dst_t);
0 commit comments