Skip to content

Commit 5a0e7a9

Browse files
committed
sycl: simplify bin_bcast_kernel
1 parent 814f795 commit 5a0e7a9

File tree

1 file changed

+123
-231
lines changed

1 file changed

+123
-231
lines changed

ggml/src/ggml-sycl/binbcast.cpp

+123-231
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,75 @@
11
#include "binbcast.hpp"
22

3+
#include <array>
34
#include <cstddef>
45
#include <cstdint>
6+
#include <cstdio>
57
#include <sycl/sycl.hpp>
68

79
#include "ggml.h"
810

9-
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
10-
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
11-
int ne0, int ne1, int ne2, int ne3,
12-
int ne10, int ne11, int ne12, int ne13,
13-
/*int s0, */ int s1, int s2, int s3,
14-
/*int s00,*/ int s01, int s02, int s03,
15-
/*int s10,*/ int s11, int s12, int s13,
16-
const sycl::nd_item<3> &item_ct1) {
17-
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
18-
item_ct1.get_local_id(2);
19-
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
20-
item_ct1.get_local_id(1));
21-
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
22-
item_ct1.get_local_id(0)) /
23-
ne3;
24-
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
25-
item_ct1.get_local_id(0)) %
26-
ne3;
27-
28-
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
29-
return;
30-
}
31-
32-
const int i11 = i1 % ne11;
33-
const int i12 = i2 % ne12;
34-
const int i13 = i3 % ne13;
35-
36-
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
37-
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
38-
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
39-
40-
const src0_t * src0_row = src0 + i_src0;
41-
const src1_t * src1_row = src1 + i_src1;
42-
dst_t * dst_row = dst + i_dst;
43-
44-
for (int i0 = i0s; i0 < ne0;
45-
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
46-
const int i10 = i0 % ne10;
47-
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
11+
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
12+
static __attribute((always_inline)) inline void k_bin_bcast_contiguous(const src0_t * __restrict__ src0,
13+
const src1_t * __restrict__ src1, dst_t * dst,
14+
std::size_t num_elements,
15+
const sycl::nd_item<1> & it) {
16+
auto element_id = it.get_global_id(0);
17+
auto global_range = it.get_global_range(0);
18+
for (; element_id < num_elements; element_id += global_range) {
19+
auto src0_float_val = sycl::vec(src0[element_id]).template convert<float, sycl::rounding_mode::rte>();
20+
auto src1_float_val = sycl::vec(src1[element_id]).template convert<float, sycl::rounding_mode::rte>();
21+
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
22+
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
23+
dst[element_id] = val_to_store;
4824
}
4925
}
5026

51-
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
52-
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
53-
int ne0, int ne1, int ne2, int ne3,
54-
int ne10, int ne11, int ne12, int ne13,
55-
/*int s0, */ int s1, int s2, int s3,
56-
/*int s00,*/ int s01, int s02, int s03,
57-
/*int s10,*/ int s11, int s12, int s13,
58-
const sycl::nd_item<3> &item_ct1) {
59-
60-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
61-
item_ct1.get_local_id(2);
62-
63-
const int i3 = i/(ne2*ne1*ne0);
64-
const int i2 = (i/(ne1*ne0)) % ne2;
65-
const int i1 = (i/ne0) % ne1;
66-
const int i0 = i % ne0;
67-
68-
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
69-
return;
27+
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
28+
static __attribute((always_inline)) inline void k_bin_bcast(
29+
const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst, int ne0, int ne1, int ne2, int ne3,
30+
int ne10, int ne11, int ne12, int ne13, int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10,
31+
int s11, int s12, int s13, std::size_t num_dst_elements, const sycl::nd_item<1> & item_ct1) {
32+
auto calculate_logical_index =
33+
[](const std::array<int, 4> & dims, std::size_t element_id) __attribute__((always_inline))->std::array<int, 4> {
34+
std::array<int, 4> logical_index;
35+
#pragma unroll(4)
36+
for (int i = 3; i >= 0; i--) {
37+
logical_index[i] = element_id % dims[i];
38+
element_id /= dims[i];
39+
}
40+
return logical_index;
41+
};
42+
43+
auto calculate_index = [](const std::array<int, 4> & dims, const std::array<int, 4> & strides,
44+
const std::array<int, 4> & indices) __attribute__((always_inline))
45+
->std::size_t {
46+
std::size_t index = 0;
47+
#pragma unroll(4)
48+
for (int i = 0; i < 4; i++) {
49+
auto index_i = indices[i];
50+
if (indices[i] >= dims[i]) {
51+
index_i = indices[i] % dims[i];
52+
}
53+
index += strides[i] * index_i;
54+
}
55+
return index;
56+
};
57+
58+
auto element_id = item_ct1.get_global_id(0);
59+
for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) {
60+
auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id);
61+
auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index);
62+
auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index);
63+
auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index);
64+
auto src0_float_val = sycl::vec(src0[src_0_index]).template convert<float, sycl::rounding_mode::rte>();
65+
auto src1_float_val = sycl::vec(src1[src_1_index]).template convert<float, sycl::rounding_mode::rte>();
66+
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
67+
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
68+
dst[dst_index] = val_to_store;
7069
}
71-
72-
const int i11 = i1 % ne11;
73-
const int i12 = i2 % ne12;
74-
const int i13 = i3 % ne13;
75-
76-
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
77-
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
78-
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
79-
80-
const src0_t * src0_row = src0 + i_src0;
81-
const src1_t * src1_row = src1 + i_src1;
82-
dst_t * dst_row = dst + i_dst;
83-
84-
const int i10 = i0 % ne10;
85-
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
8670
}
8771

88-
89-
template<float (*bin_op)(const float, const float)>
90-
struct bin_bcast_sycl {
72+
template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
9173
template <typename src0_t, typename src1_t, typename dst_t>
9274
void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
9375
const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
@@ -96,165 +78,75 @@ struct bin_bcast_sycl {
9678
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
9779
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
9880
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
99-
int nr0 = ne10 / ne0;
100-
int nr1 = ne11/ne1;
101-
int nr2 = ne12/ne2;
102-
int nr3 = ne13/ne3;
103-
104-
int nr[4] = { nr0, nr1, nr2, nr3 };
105-
106-
// collapse dimensions until first broadcast dimension
107-
int64_t cne[] = {ne0, ne1, ne2, ne3};
108-
int64_t cne0[] = {ne00, ne01, ne02, ne03};
109-
int64_t cne1[] = {ne10, ne11, ne12, ne13};
110-
size_t cnb[] = {nb0, nb1, nb2, nb3};
111-
size_t cnb0[] = {nb00, nb01, nb02, nb03};
112-
size_t cnb1[] = {nb10, nb11, nb12, nb13};
113-
auto collapse = [](int64_t cne[]) {
114-
cne[0] *= cne[1];
115-
cne[1] = cne[2];
116-
cne[2] = cne[3];
117-
cne[3] = 1;
118-
};
119-
120-
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
121-
cnb[1] *= cne[1];
122-
cnb[2] *= cne[2];
123-
cnb[3] *= cne[3];
124-
};
125-
126-
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
81+
auto check_bcast_required = [](const std::array<int64_t, 4> & src_dims,
82+
const std::array<int64_t, 4> & dst_dims) -> bool {
83+
bool requires_broadcast = false;
12784
for (int i = 0; i < 4; i++) {
128-
if (nr[i] != 1) {
85+
if (dst_dims[i] > src_dims[i]) {
86+
requires_broadcast = true;
12987
break;
13088
}
131-
if (i > 0) {
132-
collapse_nb(cnb, cne);
133-
collapse_nb(cnb0, cne0);
134-
collapse_nb(cnb1, cne1);
135-
collapse(cne);
136-
collapse(cne0);
137-
collapse(cne1);
138-
}
139-
}
140-
}
141-
{
142-
int64_t ne0 = cne[0];
143-
int64_t ne1 = cne[1];
144-
int64_t ne2 = cne[2];
145-
int64_t ne3 = cne[3];
146-
147-
int64_t ne10 = cne1[0];
148-
int64_t ne11 = cne1[1];
149-
int64_t ne12 = cne1[2];
150-
int64_t ne13 = cne1[3];
151-
152-
size_t nb0 = cnb[0];
153-
size_t nb1 = cnb[1];
154-
size_t nb2 = cnb[2];
155-
size_t nb3 = cnb[3];
156-
157-
size_t nb00 = cnb0[0];
158-
size_t nb01 = cnb0[1];
159-
size_t nb02 = cnb0[2];
160-
size_t nb03 = cnb0[3];
161-
162-
size_t nb10 = cnb1[0];
163-
size_t nb11 = cnb1[1];
164-
size_t nb12 = cnb1[2];
165-
size_t nb13 = cnb1[3];
166-
167-
size_t s0 = nb0 / sizeof(dst_t);
168-
size_t s1 = nb1 / sizeof(dst_t);
169-
size_t s2 = nb2 / sizeof(dst_t);
170-
size_t s3 = nb3 / sizeof(dst_t);
171-
172-
size_t s10 = nb10 / sizeof(src1_t);
173-
size_t s11 = nb11 / sizeof(src1_t);
174-
size_t s12 = nb12 / sizeof(src1_t);
175-
size_t s13 = nb13 / sizeof(src1_t);
176-
177-
size_t s00 = nb00 / sizeof(src0_t);
178-
size_t s01 = nb01 / sizeof(src0_t);
179-
size_t s02 = nb02 / sizeof(src0_t);
180-
size_t s03 = nb03 / sizeof(src0_t);
181-
182-
GGML_UNUSED(s00);
183-
184-
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
185-
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
186-
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
187-
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
188-
189-
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
190-
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
191-
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
192-
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
193-
194-
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
195-
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
196-
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
197-
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
198-
199-
GGML_ASSERT(s0 == 1);
200-
GGML_ASSERT(s10 == 1);
201-
202-
const int block_size = 128;
203-
204-
int64_t hne0 = std::max(ne0/2LL, 1LL);
205-
206-
sycl::range<3> block_dims(1, 1, 1);
207-
block_dims[2] = std::min<unsigned int>(hne0, block_size);
208-
block_dims[1] = std::min<unsigned int>(
209-
ne1, block_size / (unsigned int)block_dims[2]);
210-
block_dims[0] = std::min(
211-
std::min<unsigned int>(
212-
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
213-
(unsigned int)block_dims[1]),
214-
64U);
215-
216-
sycl::range<3> block_nums(
217-
(ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
218-
(ne1 + block_dims[1] - 1) / block_dims[1],
219-
(hne0 + block_dims[2] - 1) / block_dims[2]);
220-
221-
if (block_nums[0] > 65535) {
222-
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
223-
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
224-
{
225-
dpct::has_capability_or_fail(stream->get_device(),
226-
{sycl::aspect::fp16});
227-
228-
stream->parallel_for(
229-
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
230-
sycl::range<3>(1, 1, block_size),
231-
sycl::range<3>(1, 1, block_size)),
232-
[=](sycl::nd_item<3> item_ct1) {
233-
k_bin_bcast_unravel<bin_op>(
234-
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
235-
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
236-
s03, s11, s12, s13, item_ct1);
237-
});
238-
}
239-
} else {
240-
/*
241-
DPCT1049:16: The work-group size passed to the SYCL kernel may
242-
exceed the limit. To get the device limit, query
243-
info::device::max_work_group_size. Adjust the work-group size if
244-
needed.
245-
*/
246-
dpct::has_capability_or_fail(stream->get_device(),
247-
{sycl::aspect::fp16});
248-
249-
stream->parallel_for(
250-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
251-
[=](sycl::nd_item<3> item_ct1) {
252-
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
253-
ne2, ne3, ne10, ne11, ne12, ne13,
254-
s1, s2, s3, s01, s02, s03, s11, s12, s13,
255-
item_ct1);
256-
});
25789
}
90+
return requires_broadcast;
91+
};
92+
93+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
94+
95+
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
96+
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
97+
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
98+
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
99+
100+
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
101+
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
102+
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
103+
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
104+
105+
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
106+
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
107+
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
108+
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
109+
110+
// dst strides in number of elements
111+
size_t s0 = nb0 / sizeof(dst_t);
112+
size_t s1 = nb1 / sizeof(dst_t);
113+
size_t s2 = nb2 / sizeof(dst_t);
114+
size_t s3 = nb3 / sizeof(dst_t);
115+
116+
// src1 strides in number of elements
117+
size_t s10 = nb10 / sizeof(src0_t);
118+
size_t s11 = nb11 / sizeof(src1_t);
119+
size_t s12 = nb12 / sizeof(src1_t);
120+
size_t s13 = nb13 / sizeof(src1_t);
121+
122+
// src0 strides in number of elements
123+
size_t s00 = nb00 / sizeof(src0_t);
124+
size_t s01 = nb01 / sizeof(src0_t);
125+
size_t s02 = nb02 / sizeof(src0_t);
126+
size_t s03 = nb03 / sizeof(src0_t);
127+
128+
std::size_t num_dst_elements = static_cast<std::size_t>(ne0) * static_cast<std::size_t>(ne1) *
129+
static_cast<std::size_t>(ne2) * static_cast<std::size_t>(ne3);
130+
std::size_t local_range = 256;
131+
std::size_t global_range = ((num_dst_elements + local_range - 1) / local_range) * local_range;
132+
133+
bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) ||
134+
check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 });
135+
bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous;
136+
137+
if (not needs_broadcasting && all_contiguous) {
138+
stream->submit([&](sycl::handler & cgh) {
139+
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
140+
k_bin_bcast_contiguous<bin_op>(src0_dd, src1_dd, dst_dd, num_dst_elements, it);
141+
});
142+
});
143+
} else {
144+
stream->submit([&](sycl::handler & cgh) {
145+
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
146+
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1,
147+
s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it);
148+
});
149+
});
258150
}
259151
}
260152
};

0 commit comments

Comments
 (0)