Implement tensor.isin
#2098
base: master
View rendered docs @ https://intelpython.github.io/dpctl/pulls/2098/index.html
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_8 ran successfully.
1805102 to 5355fb8
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_10 ran successfully.
dpctl/tensor/_set_functions.py
dep_evs = _manager.submitted_events
ht_ev, s_ev = _isin(
    needles=x1,
It seems the strided implementation (which is presumably slower) will only be used when the `x1` array is not contiguous (we sort the `test_elements` array, and there is no `out` keyword in the `isin` function).
Would it make sense to flatten the input array `x` and pass the `order` keyword there?
That said, it also makes sense to keep the strided implementation of `_isin` in case it proves helpful when implementing other set functions.
I can experiment and see the value of flattening vs. not flattening, but in general this implementation is going to change quite a bit soon; I have some local changes waiting.
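As a rough illustration of the flattening idea discussed above, here is a minimal sketch that uses NumPy as a stand-in for the dpctl kernels. The helper names `isin_contig_1d` and `isin_flatten_first` are hypothetical and are not part of this PR; the sketch only shows that a non-contiguous input can be copied into a contiguous 1-D buffer, tested on the contiguous path, and reshaped back.

```python
import numpy as np


def isin_contig_1d(needles, sorted_hay):
    """Membership test for a 1-D array against a sorted 1-D array."""
    if sorted_hay.size == 0:
        return np.zeros(needles.shape, dtype=bool)
    pos = np.searchsorted(sorted_hay, needles)
    found = pos < sorted_hay.size
    # Clamp out-of-range positions so the gather below stays in bounds.
    clamped = np.minimum(pos, sorted_hay.size - 1)
    return found & (sorted_hay[clamped] == needles)


def isin_flatten_first(x, test_elements):
    """Flatten `x` to a contiguous 1-D array, test, then restore the shape."""
    x = np.asarray(x)
    hay = np.sort(np.asarray(test_elements), axis=None)
    # np.ravel returns a view for C-contiguous input and a copy otherwise.
    flat = np.ravel(x, order="C")
    return isin_contig_1d(flat, hay).reshape(x.shape)
```

Whether the extra copy is cheaper than a strided kernel depends on the device and array sizes, so this would need to be measured rather than assumed.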
@antonwolfy tests still need to be added, but |
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_17 ran successfully.
@@ -112,6 +112,7 @@ set(_reduction_sources
    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp
)
set(_sorting_sources
    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp
This doesn't seem related to a sorting routine.
It uses common utilities with `searchsorted` (i.e., from `rich_comparisons.hpp`), which is why it lives there.
If the code from `rich_comparisons` gets factored out, I can go ahead and move it elsewhere, I guess to `_tensor_impl` for now.
fnT get() const
{
    using dpctl::tensor::kernels::isin_contig_impl;
    using Compare = typename AscendingSorter<argTy>::type;
Do we need `Compare` to be templated here? Is there any possible use case in which the `isin` kernel would need a different comparator?
It's not strictly necessary, but it was done to reduce code duplication: these sorters are defined in `tensor/source`.
I can look at the normal sort implementation to refresh my memory of what was done there, but if that isn't sufficient, it may be preferable to template and pass the sorter here rather than duplicating it in `isin.hpp`.
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_18 ran successfully.
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_22 ran successfully.
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_23 ran successfully.
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_24 ran successfully.
b3822f3 to f7e0967
if not isinstance(x, dpt.usm_ndarray):
    x_arr = dpt.asarray(
        x, dtype=dt1, usm_type=res_usm_type, sycl_queue=exec_q
Should we cast to the result dtype here to avoid an unnecessary copy below?
- x, dtype=dt1, usm_type=res_usm_type, sycl_queue=exec_q
+ x, dtype=dt, usm_type=res_usm_type, sycl_queue=exec_q
The same applies to the `test_arr` casting here.
And then it would probably make sense to combine the checks, like:

    if not isinstance(x, dpt.usm_ndarray):
        x_buf = dpt.asarray(
            x, dtype=dt, usm_type=res_usm_type, sycl_queue=exec_q
        )
    elif x_dt != dt:
        x_buf = _empty_like_orderK(x, dt, res_usm_type, sycl_dev)
        ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
            src=x, dst=x_buf, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, ev)
    else:
        x_buf = x
I've left it this way because this is how element-wise functions handle it as well: scalars are first put into the appropriate array type, and then the array is cast to another type for computation.
I'm not sure it's strictly necessary, but it may avoid some edge cases producing incorrect results.
@pytest.mark.parametrize(
    "dtype",
What about the case when the inputs have different dtypes and casting is required?
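One way such a mixed-dtype check could be structured is sketched below. This is only an illustration: it uses NumPy in place of `dpctl.tensor`, the helper names `isin_with_promotion` and `check_mixed_dtypes` are hypothetical, and the dtype list is an arbitrary choice rather than the set used in the PR's test suite.

```python
import itertools

import numpy as np

# Arbitrary sample of dtypes for illustration only.
DTYPES = ["i4", "i8", "f4", "f8"]


def isin_with_promotion(x, test_elements):
    """Cast both inputs to their common dtype before testing membership."""
    dt = np.result_type(x.dtype, test_elements.dtype)
    x_buf = x.astype(dt, copy=False)
    test_buf = test_elements.astype(dt, copy=False)
    return np.isin(x_buf, test_buf)


def check_mixed_dtypes():
    """Run the membership check over every pair of input dtypes."""
    expected = np.array([False, True, False, False, True, False])
    for dt1, dt2 in itertools.product(DTYPES, repeat=2):
        x = np.arange(6, dtype=dt1)
        test = np.array([1, 4, 7], dtype=dt2)
        result = isin_with_promotion(x, test)
        assert result.dtype == np.bool_
        assert np.array_equal(result, expected), (dt1, dt2)
    return True
```

In a pytest suite the double loop would more naturally become two stacked `@pytest.mark.parametrize` decorators, matching the style of the test shown above.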
8f19cb5 to 3cf7445
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_26 ran successfully.
3cf7445 to 7c6a4be
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_27 ran successfully.
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_30 ran successfully.
isin leverages kernel very similar to searchsorted, but after the search, the position is checked, and if the position is equal to the number of elements in the searched array, existence is considered false
permit scalar input for second argument, address some review comments, add docstring
radix sort implementation asserts that array must be of size > 1
16c63e4 to 23c61a8
Array API standard conformance tests for dpctl=0.21.0dev0=py310h93fe807_35 ran successfully.
This PR proposes an implementation for `isin`, a function likely coming to a future array API specification, which leverages a kernel similar to the one used in the implementation of `searchsorted`.
This implementation uses the `searchsorted` kernel to check whether the value has a position in the array. If that position is the number of elements in the array, the value is not a member. Otherwise, if `arr[pos] == val` for some array `arr` being searched for value `val`, then `val` is a member.