#include "binbcast.hpp"

+ #include <array>
#include <cstddef>
#include <cstdint>
+ #include <cstdio>
#include <sycl/sycl.hpp>

#include "ggml.h"

- template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
- static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-                         int ne0, int ne1, int ne2, int ne3,
-                         int ne10, int ne11, int ne12, int ne13,
-                         /*int s0, */ int s1, int s2, int s3,
-                         /*int s00,*/ int s01, int s02, int s03,
-                         /*int s10,*/ int s11, int s12, int s13,
-                         const sycl::nd_item<3> &item_ct1) {
-     const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
-     const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
-     const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + item_ct1.get_local_id(0)) / ne3;
-     const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + item_ct1.get_local_id(0)) % ne3;
-
-     if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-         return;
-     }
-
-     const int i11 = i1 % ne11;
-     const int i12 = i2 % ne12;
-     const int i13 = i3 % ne13;
-
-     const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
-     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-     const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;
-
-     const src0_t * src0_row = src0 + i_src0;
-     const src1_t * src1_row = src1 + i_src1;
-     dst_t * dst_row = dst + i_dst;
-
-     for (int i0 = i0s; i0 < ne0; i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
-         const int i10 = i0 % ne10;
-         dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
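+ // Fast path for fully contiguous tensors with no broadcasting: a flat 1D
+ // grid-stride loop over the destination elements.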
+ template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+ static __attribute__((always_inline)) inline void k_bin_bcast_contiguous(const src0_t * __restrict__ src0,
+                                                                          const src1_t * __restrict__ src1, dst_t * dst,
+                                                                          std::size_t num_elements,
+                                                                          const sycl::nd_item<1> & it) {
+     auto element_id   = it.get_global_id(0);
+     auto global_range = it.get_global_range(0);
+     for (; element_id < num_elements; element_id += global_range) {
+         auto  src0_float_val = sycl::vec(src0[element_id]).template convert<float, sycl::rounding_mode::rte>();
+         auto  src1_float_val = sycl::vec(src1[element_id]).template convert<float, sycl::rounding_mode::rte>();
+         float dst_val        = bin_op(src0_float_val[0], src1_float_val[0]);
+         auto  val_to_store   = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
+         dst[element_id]      = val_to_store;
    }
}

- template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
- static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
-                                 int ne0, int ne1, int ne2, int ne3,
-                                 int ne10, int ne11, int ne12, int ne13,
-                                 /*int s0, */ int s1, int s2, int s3,
-                                 /*int s00,*/ int s01, int s02, int s03,
-                                 /*int s10,*/ int s11, int s12, int s13,
-                                 const sycl::nd_item<3> &item_ct1) {
-
-     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
-
-     const int i3 = i/(ne2*ne1*ne0);
-     const int i2 = (i/(ne1*ne0)) % ne2;
-     const int i1 = (i/ne0) % ne1;
-     const int i0 = i % ne0;
-
-     if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-         return;
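+ // General path: each destination element's 4D logical index is recovered and
+ // mapped through per-tensor strides, with index wraparound implementing the
+ // broadcast semantics.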
+ template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+ static __attribute__((always_inline)) inline void k_bin_bcast(
+     const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst, int ne0, int ne1, int ne2, int ne3,
+     int ne10, int ne11, int ne12, int ne13, int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10,
+     int s11, int s12, int s13, std::size_t num_dst_elements, const sycl::nd_item<1> & item_ct1) {
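+     // Unflattens a linear destination element id into a 4D logical index (i3, i2, i1, i0).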
+     auto calculate_logical_index =
+         [](const std::array<int, 4> & dims, std::size_t element_id) __attribute__((always_inline)) -> std::array<int, 4> {
+         std::array<int, 4> logical_index;
+ #pragma unroll(4)
+         for (int i = 3; i >= 0; i--) {
+             logical_index[i] = element_id % dims[i];
+             element_id /= dims[i];
+         }
+         return logical_index;
+     };
+
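+     // Maps a 4D logical index to a flat offset using a tensor's own dims and
+     // strides; an index past a dimension's extent wraps around (broadcast).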
+     auto calculate_index = [](const std::array<int, 4> & dims, const std::array<int, 4> & strides,
+                               const std::array<int, 4> & indices) __attribute__((always_inline))
+         -> std::size_t {
+         std::size_t index = 0;
+ #pragma unroll(4)
+         for (int i = 0; i < 4; i++) {
+             auto index_i = indices[i];
+             if (indices[i] >= dims[i]) {
+                 index_i = indices[i] % dims[i];
+             }
+             index += strides[i] * index_i;
+         }
+         return index;
+     };
+
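+     // Grid-stride loop over destination elements: compute each operand's
+     // offset, apply bin_op in float, and convert back to the destination type.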
+     auto element_id = item_ct1.get_global_id(0);
+     for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) {
+         auto  logical_index  = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id);
+         auto  src_0_index    = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index);
+         auto  src_1_index    = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index);
+         auto  dst_index      = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index);
+         auto  src0_float_val = sycl::vec(src0[src_0_index]).template convert<float, sycl::rounding_mode::rte>();
+         auto  src1_float_val = sycl::vec(src1[src_1_index]).template convert<float, sycl::rounding_mode::rte>();
+         float dst_val        = bin_op(src0_float_val[0], src1_float_val[0]);
+         auto  val_to_store   = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
+         dst[dst_index]       = val_to_store;
    }
-
-     const int i11 = i1 % ne11;
-     const int i12 = i2 % ne12;
-     const int i13 = i3 % ne13;
-
-     const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
-     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-     const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;
-
-     const src0_t * src0_row = src0 + i_src0;
-     const src1_t * src1_row = src1 + i_src1;
-     dst_t * dst_row = dst + i_dst;
-
-     const int i10 = i0 % ne10;
-     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
}

-
- template <float (*bin_op)(const float, const float)>
- struct bin_bcast_sycl {
+ template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
    template <typename src0_t, typename src1_t, typename dst_t>
    void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
                    const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
@@ -96,165 +78,75 @@ struct bin_bcast_sycl {
                    const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
                    const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
                    const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
-         int nr0 = ne10 / ne0;
-         int nr1 = ne11 / ne1;
-         int nr2 = ne12 / ne2;
-         int nr3 = ne13 / ne3;
-
-         int nr[4] = { nr0, nr1, nr2, nr3 };
-
-         // collapse dimensions until first broadcast dimension
-         int64_t cne[] = {ne0, ne1, ne2, ne3};
-         int64_t cne0[] = {ne00, ne01, ne02, ne03};
-         int64_t cne1[] = {ne10, ne11, ne12, ne13};
-         size_t cnb[] = {nb0, nb1, nb2, nb3};
-         size_t cnb0[] = {nb00, nb01, nb02, nb03};
-         size_t cnb1[] = {nb10, nb11, nb12, nb13};
-         auto collapse = [](int64_t cne[]) {
-             cne[0] *= cne[1];
-             cne[1] = cne[2];
-             cne[2] = cne[3];
-             cne[3] = 1;
-         };
-
-         auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
-             cnb[1] *= cne[1];
-             cnb[2] *= cne[2];
-             cnb[3] *= cne[3];
-         };
-
-         if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
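+         // A source only needs broadcasting when the destination is larger
+         // along some dimension.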
+         auto check_bcast_required = [](const std::array<int64_t, 4> & src_dims,
+                                        const std::array<int64_t, 4> & dst_dims) -> bool {
+             bool requires_broadcast = false;
              for (int i = 0; i < 4; i++) {
-                 if (nr[i] != 1) {
+                 if (dst_dims[i] > src_dims[i]) {
+                     requires_broadcast = true;
                      break;
                  }
-                 if (i > 0) {
-                     collapse_nb(cnb, cne);
-                     collapse_nb(cnb0, cne0);
-                     collapse_nb(cnb1, cne1);
-                     collapse(cne);
-                     collapse(cne0);
-                     collapse(cne1);
-                 }
-             }
-         }
-         {
-             int64_t ne0 = cne[0];
-             int64_t ne1 = cne[1];
-             int64_t ne2 = cne[2];
-             int64_t ne3 = cne[3];
-
-             int64_t ne10 = cne1[0];
-             int64_t ne11 = cne1[1];
-             int64_t ne12 = cne1[2];
-             int64_t ne13 = cne1[3];
-
-             size_t nb0 = cnb[0];
-             size_t nb1 = cnb[1];
-             size_t nb2 = cnb[2];
-             size_t nb3 = cnb[3];
-
-             size_t nb00 = cnb0[0];
-             size_t nb01 = cnb0[1];
-             size_t nb02 = cnb0[2];
-             size_t nb03 = cnb0[3];
-
-             size_t nb10 = cnb1[0];
-             size_t nb11 = cnb1[1];
-             size_t nb12 = cnb1[2];
-             size_t nb13 = cnb1[3];
-
-             size_t s0 = nb0 / sizeof(dst_t);
-             size_t s1 = nb1 / sizeof(dst_t);
-             size_t s2 = nb2 / sizeof(dst_t);
-             size_t s3 = nb3 / sizeof(dst_t);
-
-             size_t s10 = nb10 / sizeof(src1_t);
-             size_t s11 = nb11 / sizeof(src1_t);
-             size_t s12 = nb12 / sizeof(src1_t);
-             size_t s13 = nb13 / sizeof(src1_t);
-
-             size_t s00 = nb00 / sizeof(src0_t);
-             size_t s01 = nb01 / sizeof(src0_t);
-             size_t s02 = nb02 / sizeof(src0_t);
-             size_t s03 = nb03 / sizeof(src0_t);
-
-             GGML_UNUSED(s00);
-
-             GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
-             GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
-             GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
-             GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
-
-             GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
-             GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
-             GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
-             GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
-
-             GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
-             GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
-             GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
-             GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
-
-             GGML_ASSERT(s0 == 1);
-             GGML_ASSERT(s10 == 1);
-
-             const int block_size = 128;
-
-             int64_t hne0 = std::max(ne0/2LL, 1LL);
-
-             sycl::range<3> block_dims(1, 1, 1);
-             block_dims[2] = std::min<unsigned int>(hne0, block_size);
-             block_dims[1] = std::min<unsigned int>(ne1, block_size / (unsigned int)block_dims[2]);
-             block_dims[0] = std::min(
-                 std::min<unsigned int>(ne2 * ne3, block_size / (unsigned int)block_dims[2] / (unsigned int)block_dims[1]),
-                 64U);
-
-             sycl::range<3> block_nums(
-                 (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
-                 (ne1 + block_dims[1] - 1) / block_dims[1],
-                 (hne0 + block_dims[2] - 1) / block_dims[2]);
-
-             if (block_nums[0] > 65535) {
-                 // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
-                 int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
-                 {
-                     dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});
-
-                     stream->parallel_for(
-                         sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * sycl::range<3>(1, 1, block_size),
-                                           sycl::range<3>(1, 1, block_size)),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             k_bin_bcast_unravel<bin_op>(
-                                 src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
-                                 ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
-                                 s03, s11, s12, s13, item_ct1);
-                         });
-                 }
-             } else {
-                 /*
-                 DPCT1049:16: The work-group size passed to the SYCL kernel may
-                 exceed the limit. To get the device limit, query
-                 info::device::max_work_group_size. Adjust the work-group size if
-                 needed.
-                 */
-                 dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});
-
-                 stream->parallel_for(
-                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                     [=](sycl::nd_item<3> item_ct1) {
-                         k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
-                                             ne2, ne3, ne10, ne11, ne12, ne13,
-                                             s1, s2, s3, s01, s02, s03, s11, s12, s13,
-                                             item_ct1);
-                     });
            }
+             return requires_broadcast;
+         };
+
+         dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+
+         GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+         GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+         GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+         GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+         GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+         GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+         GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+         GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+         GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+         GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+         GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+         GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+         // dst strides in number of elements
+         size_t s0 = nb0 / sizeof(dst_t);
+         size_t s1 = nb1 / sizeof(dst_t);
+         size_t s2 = nb2 / sizeof(dst_t);
+         size_t s3 = nb3 / sizeof(dst_t);
+
+         // src1 strides in number of elements
+         size_t s10 = nb10 / sizeof(src1_t);
+         size_t s11 = nb11 / sizeof(src1_t);
+         size_t s12 = nb12 / sizeof(src1_t);
+         size_t s13 = nb13 / sizeof(src1_t);
+
+         // src0 strides in number of elements
+         size_t s00 = nb00 / sizeof(src0_t);
+         size_t s01 = nb01 / sizeof(src0_t);
+         size_t s02 = nb02 / sizeof(src0_t);
+         size_t s03 = nb03 / sizeof(src0_t);
+
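+         // 1D launch configuration: round the destination element count up to
+         // a whole number of work-groups.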
+         std::size_t num_dst_elements = static_cast<std::size_t>(ne0) * static_cast<std::size_t>(ne1) *
+                                        static_cast<std::size_t>(ne2) * static_cast<std::size_t>(ne3);
+         std::size_t local_range  = 256;
+         std::size_t global_range = ((num_dst_elements + local_range - 1) / local_range) * local_range;
+
+         bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) ||
+                                   check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 });
+         bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous;
+
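+         // Dispatch: use the contiguous fast path when no broadcasting is
+         // needed and every tensor is contiguous; otherwise fall back to the
+         // stride-aware kernel.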
+         if (not needs_broadcasting && all_contiguous) {
+             stream->submit([&](sycl::handler & cgh) {
+                 cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
+                     k_bin_bcast_contiguous<bin_op>(src0_dd, src1_dd, dst_dd, num_dst_elements, it);
+                 });
+             });
+         } else {
+             stream->submit([&](sycl::handler & cgh) {
+                 cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
+                     k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1,
+                                         s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it);
+                 });
+             });
        }
    }
};