@@ -79,12 +79,23 @@ def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
79
79
raise TypeError (
80
80
f"Expected type dpctl.tensor.usm_ndarray, got { type (x )} "
81
81
)
82
+ if not isinstance (kind , str ) or kind not in [
83
+ "stable" ,
84
+ "radixsort" ,
85
+ "mergesort" ,
86
+ ]:
87
+ raise ValueError (
88
+ "Unsupported kind value. Expected 'stable', 'mergesort', "
89
+ f"or 'radixsort', but got '{ kind } '"
90
+ )
82
91
nd = x .ndim
83
92
if nd == 0 :
84
93
axis = normalize_axis_index (axis , ndim = 1 , msg_prefix = "axis" )
85
94
return dpt .copy (x , order = "C" )
86
95
else :
87
96
axis = normalize_axis_index (axis , ndim = nd , msg_prefix = "axis" )
97
+ if x .size == 1 :
98
+ return dpt .copy (x , order = "C" )
88
99
a1 = axis + 1
89
100
if a1 == nd :
90
101
perm = list (range (nd ))
@@ -96,15 +107,6 @@ def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
96
107
arr = dpt .permute_dims (x , perm )
97
108
if kind is None :
98
109
kind = "stable"
99
- if not isinstance (kind , str ) or kind not in [
100
- "stable" ,
101
- "radixsort" ,
102
- "mergesort" ,
103
- ]:
104
- raise ValueError (
105
- "Unsupported kind value. Expected 'stable', 'mergesort', "
106
- f"or 'radixsort', but got '{ kind } '"
107
- )
108
110
if kind == "mergesort" :
109
111
impl_fn = _get_mergesort_impl_fn (descending )
110
112
elif kind == "radixsort" :
0 commit comments