Skip to content

GH-46572: [Python] expose filter option to python for join #46566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions python/pyarrow/_acero.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,15 @@ cdef class _HashJoinNodeOptions(ExecNodeOptions):

def _set_options(
self, join_type, left_keys, right_keys, left_output=None, right_output=None,
output_suffix_for_left="", output_suffix_for_right="",
output_suffix_for_left="", output_suffix_for_right="", Expression filter_expression=None,
):
cdef:
CJoinType c_join_type
vector[CFieldRef] c_left_keys
vector[CFieldRef] c_right_keys
vector[CFieldRef] c_left_output
vector[CFieldRef] c_right_output
CExpression c_filter_expression

# join type
if join_type == "left semi":
Expand Down Expand Up @@ -312,6 +313,11 @@ cdef class _HashJoinNodeOptions(ExecNodeOptions):
for key in right_keys:
c_right_keys.push_back(_ensure_field_ref(key))

if filter_expression is None:
c_filter_expression = _true
else:
c_filter_expression = filter_expression.unwrap()

# left/right output fields
if left_output is not None and right_output is not None:
for colname in left_output:
Expand All @@ -323,7 +329,7 @@ cdef class _HashJoinNodeOptions(ExecNodeOptions):
new CHashJoinNodeOptions(
c_join_type, c_left_keys, c_right_keys,
c_left_output, c_right_output,
_true,
c_filter_expression,
<c_string>tobytes(output_suffix_for_left),
<c_string>tobytes(output_suffix_for_right)
)
Expand All @@ -332,7 +338,7 @@ cdef class _HashJoinNodeOptions(ExecNodeOptions):
self.wrapped.reset(
new CHashJoinNodeOptions(
c_join_type, c_left_keys, c_right_keys,
_true,
c_filter_expression,
<c_string>tobytes(output_suffix_for_left),
<c_string>tobytes(output_suffix_for_right)
)
Expand Down Expand Up @@ -373,15 +379,17 @@ class HashJoinNodeOptions(_HashJoinNodeOptions):
output_suffix_for_right : str
Suffix added to names of output fields coming from right input,
see `output_suffix_for_left` for details.
filter_expression : pyarrow.compute.Expression
Residual filter which is applied to matching rows.
"""

def __init__(
self, join_type, left_keys, right_keys, left_output=None, right_output=None,
output_suffix_for_left="", output_suffix_for_right=""
output_suffix_for_left="", output_suffix_for_right="", filter_expression=None,
):
self._set_options(
join_type, left_keys, right_keys, left_output, right_output,
output_suffix_for_left, output_suffix_for_right
output_suffix_for_left, output_suffix_for_right, filter_expression
)


Expand Down
6 changes: 5 additions & 1 deletion python/pyarrow/acero.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _perform_join(join_type, left_operand, left_keys,
right_operand, right_keys,
left_suffix=None, right_suffix=None,
use_threads=True, coalesce_keys=False,
output_type=Table):
output_type=Table, filter_expression=None):
"""
Perform join of two tables or datasets.
Expand Down Expand Up @@ -114,6 +114,8 @@ def _perform_join(join_type, left_operand, left_keys,
in the join result.
output_type: Table or InMemoryDataset
The output type for the exec plan result.
filter_expression : pyarrow.compute.Expression
Residual filter which is applied to matching rows.
Returns
-------
Expand Down Expand Up @@ -183,12 +185,14 @@ def _perform_join(join_type, left_operand, left_keys,
join_type, left_keys, right_keys, left_columns, right_columns,
output_suffix_for_left=left_suffix or "",
output_suffix_for_right=right_suffix or "",
filter_expression=filter_expression,
)
else:
join_opts = HashJoinNodeOptions(
join_type, left_keys, right_keys,
output_suffix_for_left=left_suffix or "",
output_suffix_for_right=right_suffix or "",
filter_expression=filter_expression,
)
decl = Declaration(
"hashjoin", options=join_opts, inputs=[left_source, right_source]
Expand Down
24 changes: 21 additions & 3 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -5631,7 +5631,7 @@ cdef class Table(_Tabular):

def join(self, right_table, keys, right_keys=None, join_type="left outer",
left_suffix=None, right_suffix=None, coalesce_keys=True,
use_threads=True):
use_threads=True, filter_expression=None):
"""
Perform a join between this table and another one.

Expand Down Expand Up @@ -5665,6 +5665,8 @@ cdef class Table(_Tabular):
in the join result.
use_threads : bool, default True
Whether to use multithreading or not.
filter_expression : pyarrow.compute.Expression
Residual filter which is applied to matching rows.

Returns
-------
Expand All @@ -5674,6 +5676,7 @@ cdef class Table(_Tabular):
--------
>>> import pandas as pd
>>> import pyarrow as pa
>>> import pyarrow.compute as pc
>>> df1 = pd.DataFrame({'id': [1, 2, 3],
... 'year': [2020, 2022, 2019]})
>>> df2 = pd.DataFrame({'id': [3, 4],
Expand Down Expand Up @@ -5724,7 +5727,7 @@ cdef class Table(_Tabular):
n_legs: [[5,100]]
animal: [["Brittle stars","Centipede"]]

Right anti join
Right anti join:

>>> t1.join(t2, 'id', join_type="right anti")
pyarrow.Table
Expand All @@ -5735,6 +5738,20 @@ cdef class Table(_Tabular):
id: [[4]]
n_legs: [[100]]
animal: [["Centipede"]]

Inner join with a filter expression that intentionally matches no rows:

>>> t1.join(t2, 'id', join_type="inner", filter_expression=pc.equal(pc.field("n_legs"), 100))
pyarrow.Table
id: int64
year: int64
n_legs: int64
animal: string
----
id: []
year: []
n_legs: []
animal: []
"""
self._assert_cpu()
if right_keys is None:
Expand All @@ -5743,7 +5760,8 @@ cdef class Table(_Tabular):
join_type, self, keys, right_table, right_keys,
left_suffix=left_suffix, right_suffix=right_suffix,
use_threads=use_threads, coalesce_keys=coalesce_keys,
output_type=Table
output_type=Table,
filter_expression=filter_expression,
)

def join_asof(self, right_table, on, by, tolerance, right_on=None, right_by=None):
Expand Down
64 changes: 64 additions & 0 deletions python/pyarrow/tests/test_acero.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,70 @@ def test_hash_join():
assert result.sort_by("a").equals(expected)


def test_hash_join_with_residual_filter():
    """Exercise HashJoinNodeOptions' filter_expression (residual filter).

    Covers: a filter on a left-side column, a filter referencing columns
    from both sides, and the degenerate always-true / always-false cases.
    """
    # Two small tables sharing the "key" column.  The join output keeps
    # both key columns, so expected tables are built positionally rather
    # than from a dict (duplicate column names).
    left = pa.table({'key': [1, 2, 3], 'a': [4, 5, 6]})
    right = pa.table({'key': [2, 3, 4], 'b': [4, 5, 6]})
    sources = [
        Declaration("table_source", options=TableSourceNodeOptions(left)),
        Declaration("table_source", options=TableSourceNodeOptions(right)),
    ]
    out_names = ["key", "a", "key", "b"]

    def run_join(join_type, predicate):
        # Build and execute a hashjoin declaration using the given
        # residual filter expression.
        opts = HashJoinNodeOptions(
            join_type, left_keys="key", right_keys="key",
            filter_expression=predicate)
        plan = Declaration("hashjoin", options=opts, inputs=sources)
        return plan.to_table()

    # Residual filter referencing a left-side column only: of the two
    # key matches (2 and 3), only the one with a == 5 survives.
    result = run_join("inner", pc.equal(pc.field('a'), 5))
    assert result.equals(
        pa.table([[2], [5], [2], [4]], names=out_names))

    # Residual filter referencing columns from both sides; left rows
    # whose match is filtered out still appear in the left outer join,
    # with nulls on the right.
    both_sides = pc.equal(pc.field("a"), 5) | pc.equal(pc.field("b"), 10)
    result = run_join("left outer", both_sides)
    assert result.equals(
        pa.table(
            [[2, 1, 3], [5, 4, 6], [2, None, None], [4, None, None]],
            names=out_names))

    # An always-true filter degenerates to a plain inner join.
    result = run_join("inner", pc.scalar(True))
    assert result.equals(
        pa.table([[2, 3], [5, 6], [2, 3], [4, 5]], names=out_names))

    # An always-false filter removes every matched row, leaving an
    # empty (but correctly typed) result.
    result = run_join("inner", pc.scalar(False))
    empty = pa.array([], type=pa.int64())
    assert result.equals(
        pa.table([empty, empty, empty, empty], names=out_names))


def test_asof_join():
left = pa.table({'key': [1, 2, 3], 'ts': [1, 1, 1], 'a': [4, 5, 6]})
left_source = Declaration("table_source", options=TableSourceNodeOptions(left))
Expand Down
Loading