Skip to content

Commit 8f19cb5

Browse files
committed
Move sort kind validation and add a fast-path for size == 1 arrays
radix sort implementation asserts that array must be of size > 1
1 parent f7e0967 commit 8f19cb5

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

dpctl/tensor/_sorting.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,23 @@ def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
7979
raise TypeError(
8080
f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
8181
)
82+
if not isinstance(kind, str) or kind not in [
83+
"stable",
84+
"radixsort",
85+
"mergesort",
86+
]:
87+
raise ValueError(
88+
"Unsupported kind value. Expected 'stable', 'mergesort', "
89+
f"or 'radixsort', but got '{kind}'"
90+
)
8291
nd = x.ndim
8392
if nd == 0:
8493
axis = normalize_axis_index(axis, ndim=1, msg_prefix="axis")
8594
return dpt.copy(x, order="C")
8695
else:
8796
axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
97+
if x.size == 1:
98+
return dpt.copy(x, order="C")
8899
a1 = axis + 1
89100
if a1 == nd:
90101
perm = list(range(nd))
@@ -96,15 +107,6 @@ def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
96107
arr = dpt.permute_dims(x, perm)
97108
if kind is None:
98109
kind = "stable"
99-
if not isinstance(kind, str) or kind not in [
100-
"stable",
101-
"radixsort",
102-
"mergesort",
103-
]:
104-
raise ValueError(
105-
"Unsupported kind value. Expected 'stable', 'mergesort', "
106-
f"or 'radixsort', but got '{kind}'"
107-
)
108110
if kind == "mergesort":
109111
impl_fn = _get_mergesort_impl_fn(descending)
110112
elif kind == "radixsort":

0 commit comments

Comments
 (0)