@@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
 int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_prioritize_dmmv = 0;

 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -195,11 +196,13 @@ static void ggml_check_sycl() try {
     g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
     g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
     g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+    g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
     GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
     GGML_LOG_INFO("Running with Environment Variables:\n");
     GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
     GGML_LOG_INFO("GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
     GGML_LOG_INFO("GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+    GGML_LOG_INFO("GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
     GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
     GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: yes\n");
@@ -2822,12 +2825,45 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
   std::exit(1);
 }

+enum class mul_mat_algo {
+    DMMV = 0,
+    MMVQ = 1,
+    MUL_MAT_SYCL = 2,
+};
+
 inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     // TODO: accuracy issues in MMQ
     GGML_UNUSED(type);
     return false;
 }

+inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
+inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
+inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
 static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -2856,7 +2892,7 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
-    auto qs_ptr = (uint8_t *)data_device + offset_blks * QK4_0 / 2;;
+    auto qs_ptr = (uint8_t *)data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;

     stream->parallel_for(
@@ -2884,25 +2920,44 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
     reorder_qw(data_device, ncols, nrows, size, 0, stream);
 }

-/*
- * This function could be called when the OP (mul_mat) function support reorder optimizition.
- */
-static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
-                            ggml_tensor * dst) {
-    if (!g_ggml_sycl_disable_optimize && // allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
-        ctx->opt_feature.reorder &&      // allow this device due to good perf, skip the devices with bad perf.
-        dst->op == GGML_OP_MUL_MAT &&    // limit to some supported cases of Q4_0, to do for more cases.
-        src0->type == GGML_TYPE_Q4_0 &&
-        src1->ne[2]==1 && src1->ne[3]==1) {
+static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
+    return !g_ggml_sycl_disable_optimize && // allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
+           ctx.opt_feature.reorder &&       // allow this device due to good perf, skip the devices with bad perf.
+           dst->op == GGML_OP_MUL_MAT &&    // limit to some supported cases of Q4_0, to do for more cases.
+           dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
+}

-        ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
-        if (!extra) return; // only happen in CI/UT permute case.
+static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
+                            ggml_tensor * dst, mul_mat_algo mm_algorithm) {
+    if (!should_reorder_tensor(*ctx, dst)) {
+        return;
+    }

-        if (extra->optimized_feature.reorder) return; // skip the tensor which is handled for reorder.
+    ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
+    if (!extra || extra->optimized_feature.reorder) {
+        return; // Skip permutations and already reordered tensors
+    }

-        reorder_qw(src0, ctx->stream());
-        extra->optimized_feature.reorder = true; // used to decode/dequan in next steps.
+    switch (mm_algorithm) {
+        case mul_mat_algo::DMMV:
+            if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
+                return;
+            }
+            break;
+        case mul_mat_algo::MMVQ:
+            if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
+                return;
+            }
+            break;
+        case mul_mat_algo::MUL_MAT_SYCL:
+            if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
+                return;
+            }
+            break;
     }
+
+    reorder_qw(src0, ctx->stream());
+    extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
 }

 static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2911,7 +2966,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     int64_t min_compute_capability = INT_MAX;

     if (split) {
-        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+        ggml_backend_sycl_split_buffer_type_context * buft_ctx =
+            (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
         for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
             // skip devices that are not going to do any work:
@@ -2924,7 +2980,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             }
         }
     } else {
-        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
+        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
     }

     // check data types and tensor shapes for custom matrix multiplication kernels:
@@ -2946,9 +3002,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
 #endif // SYCL_USE_XMX

+
     // mmvq path is faster in the CUDA backend.
-    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
+    if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
+        // Dispatch becomes obscure with the reorder: when the reorder optimization is
+        // enabled, MMVQ takes precedence over DMMV. The current if-else implementation
+        // requires disabling DMMV if both conditions are met.
+        || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
         use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+    }

     if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // TODO: Refactor and cleanup of mul mat dispatching.
@@ -2967,17 +3029,23 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        opt_for_reorder(&ctx, src0, src1, dst); // the OP function in this branch support reorder.
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-        // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
+        constexpr bool convert_src1_to_q8_1 = false;
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
     } else if (use_mul_mat_vec_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
     } else if (use_mul_mat_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
     } else {
-        opt_for_reorder(&ctx, src0, src1, dst); // the OP function in this branch support reorder.
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+        constexpr bool convert_src1_to_q8_1 = false;
+        // MUL_MAT_SYCL supports reorder
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
     }
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
 }

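
Taken together, the dispatch change in ggml_sycl_mul_mat means: MMVQ now takes precedence over DMMV either on a CUDA device or whenever the reorder optimization would apply to the tensor, unless the new GGML_SYCL_PRIORITIZE_DMMV environment variable is set. The following is a minimal standalone sketch of that precedence rule, not part of the patch: pick_vec_algo and its flattened boolean parameters are illustrative names only, and the getenv-based parsing is just a stand-in for the backend's get_sycl_env helper with a default of 0.

#include <cstdio>
#include <cstdlib>

enum class mul_mat_algo { DMMV = 0, MMVQ = 1, MUL_MAT_SYCL = 2 };

// Illustrative helper (not in the patch): condenses the patched selection logic in
// ggml_sycl_mul_mat into plain booleans so the precedence is easy to follow.
static mul_mat_algo pick_vec_algo(bool backend_is_cuda, bool reorder_possible,
                                  bool type_supports_reorder_mmvq,
                                  bool use_dmmv, bool use_mmvq) {
    // Stand-in for get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0): unset or "0" means off.
    const char * env = std::getenv("GGML_SYCL_PRIORITIZE_DMMV");
    const bool prioritize_dmmv = env != nullptr && std::atoi(env) != 0;

    // Same shape as the patched condition: drop DMMV when MMVQ is usable and either the
    // backend is CUDA or the reorder optimization would kick in for this tensor.
    if (!prioritize_dmmv &&
        (backend_is_cuda || (reorder_possible && type_supports_reorder_mmvq))) {
        use_dmmv = use_dmmv && !use_mmvq;
    }

    if (use_dmmv) { return mul_mat_algo::DMMV; }
    if (use_mmvq) { return mul_mat_algo::MMVQ; }
    return mul_mat_algo::MUL_MAT_SYCL;
}

int main() {
    // Q4_0 mul_mat on a non-CUDA SYCL device with the reorder feature available:
    // MMVQ is chosen unless GGML_SYCL_PRIORITIZE_DMMV=1 is exported.
    const mul_mat_algo algo = pick_vec_algo(/*backend_is_cuda=*/false,
                                            /*reorder_possible=*/true,
                                            /*type_supports_reorder_mmvq=*/true,
                                            /*use_dmmv=*/true, /*use_mmvq=*/true);
    std::printf("chosen algo: %d\n", static_cast<int>(algo));
    return 0;
}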