Skip to content

Commit f7e0967

Browse files
committed
add basic tests for isin functionality
1 parent ef9d7e3 commit f7e0967

File tree

1 file changed

+168
-0
lines changed

1 file changed

+168
-0
lines changed

dpctl/tests/test_tensor_isin.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl.tensor as dpt
20+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
21+
from dpctl.utils import ExecutionPlacementError
22+
23+
24+
@pytest.mark.parametrize(
25+
"dtype",
26+
[
27+
"i1",
28+
"u1",
29+
"i2",
30+
"u2",
31+
"i4",
32+
"u4",
33+
"i8",
34+
"u8",
35+
"f2",
36+
"f4",
37+
"f8",
38+
"c8",
39+
"c16",
40+
],
41+
)
42+
def test_isin_basic(dtype):
43+
q = get_queue_or_skip()
44+
skip_if_dtype_not_supported(dtype, q)
45+
46+
n = 100
47+
x = dpt.arange(n, dtype=dtype)
48+
test = dpt.arange(n - 1, dtype=dtype)
49+
r1 = dpt.isin(x, test)
50+
assert dpt.all(r1[:-1])
51+
assert not r1[-1]
52+
assert r1.shape == x.shape
53+
54+
# test with invert keyword
55+
r2 = dpt.isin(x, test, invert=True)
56+
assert not dpt.all(r2[:-1])
57+
assert r2[-1]
58+
assert r2.shape == x.shape
59+
60+
61+
def test_isin_basic_bool():
62+
dt = dpt.bool
63+
n = 100
64+
x = dpt.zeros(n, dtype=dt)
65+
x[-1] = True
66+
test = dpt.zeros((), dtype=dt)
67+
r1 = dpt.isin(x, test)
68+
assert dpt.all(r1[:-1])
69+
assert not r1[-1]
70+
assert r1.shape == x.shape
71+
72+
r2 = dpt.isin(x, test, invert=True)
73+
assert not dpt.all(r2[:-1])
74+
assert r2[-1]
75+
assert r2.shape == x.shape
76+
77+
78+
@pytest.mark.parametrize(
79+
"dtype",
80+
[
81+
"i1",
82+
"u1",
83+
"i2",
84+
"u2",
85+
"i4",
86+
"u4",
87+
"i8",
88+
"u8",
89+
"f2",
90+
"f4",
91+
"f8",
92+
"c8",
93+
"c16",
94+
],
95+
)
96+
def test_isin_strided(dtype):
97+
q = get_queue_or_skip()
98+
skip_if_dtype_not_supported(dtype, q)
99+
100+
n, m = 100, 20
101+
x = dpt.zeros((n, m), dtype=dtype, order="F")
102+
x[:, ::2] = dpt.arange(1, (m / 2) + 1, dtype=dtype)
103+
test = dpt.arange(1, (m / 2) + 1, dtype=dtype)
104+
r1 = dpt.isin(x, test)
105+
assert dpt.all(r1[:, ::2])
106+
assert not dpt.all(r1[:, 1::2])
107+
assert r1.shape == x.shape
108+
109+
# test with invert keyword
110+
r2 = dpt.isin(x, test, invert=True)
111+
assert not dpt.all(r2[:, ::2])
112+
assert dpt.all(r2[:, 1::2])
113+
assert r2.shape == x.shape
114+
115+
116+
def test_isin_strided_bool():
117+
dt = dpt.bool
118+
n, m = 100, 20
119+
x = dpt.ones((n, m), dtype=dt, order="F")
120+
x[:, ::2] = False
121+
test = dpt.zeros((), dtype=dt)
122+
r1 = dpt.isin(x, test)
123+
assert dpt.all(r1[:, ::2])
124+
assert not dpt.all(r1[:, 1::2])
125+
assert r1.shape == x.shape
126+
127+
# test with invert keyword
128+
r2 = dpt.isin(x, test, invert=True)
129+
assert not dpt.all(r2[:, ::2])
130+
assert dpt.all(r2[:, 1::2])
131+
assert r2.shape == x.shape
132+
133+
134+
def test_isin_empty_inputs():
135+
get_queue_or_skip()
136+
137+
x = dpt.ones((10, 0, 1), dtype="i4")
138+
test = dpt.ones((), dtype="i4")
139+
res1 = dpt.isin(x, test)
140+
assert isinstance(res1, dpt.usm_ndarray)
141+
assert res1.size == 0
142+
assert res1.shape == x.shape
143+
assert res1.dtype == dpt.bool
144+
145+
res2 = dpt.isin(x, test, invert=True)
146+
assert isinstance(res2, dpt.usm_ndarray)
147+
assert res2.size == 0
148+
assert res2.shape == x.shape
149+
assert res2.dtype == dpt.bool
150+
151+
x = dpt.ones((3, 3), dtype="i4")
152+
test = dpt.ones(0, dtype="i4")
153+
res3 = dpt.isin(x, test)
154+
assert isinstance(res3, dpt.usm_ndarray)
155+
assert res3.shape == x.shape
156+
assert res3.dtype == dpt.bool
157+
assert not dpt.all(res3)
158+
159+
res4 = dpt.isin(x, test, invert=True)
160+
assert isinstance(res4, dpt.usm_ndarray)
161+
assert res4.shape == x.shape
162+
assert res4.dtype == dpt.bool
163+
assert dpt.all(res4)
164+
165+
166+
def test_isin_validation():
167+
with pytest.raises(ExecutionPlacementError):
168+
dpt.isin(1, 1)

0 commit comments

Comments
 (0)