Skip to content

Commit 43161f9

Browse files
committed
fix: concat with union categories
1 parent 8f925fe commit 43161f9

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

awswrangler/s3/_read.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ def _extract_partitions_dtypes_from_table_details(response: "GetTableResponseTyp
116116
return dtypes
117117

118118

119-
def _union(dfs: list[pd.DataFrame], ignore_index: bool) -> pd.DataFrame:
119+
def _concat_union_categoricals(dfs: list[pd.DataFrame], ignore_index: bool) -> pd.DataFrame:
120+
"""Concatenate dataframes with union of categorical columns."""
120121
cats: tuple[set[str], ...] = tuple(set(df.select_dtypes(include="category").columns) for df in dfs)
121122
for col in set.intersection(*cats):
122123
cat = union_categoricals([df[col] for df in dfs])

awswrangler/s3/_read_parquet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
_get_path_ignore_suffix,
3939
_get_path_root,
4040
_get_paths_for_glue_table,
41+
_concat_union_categoricals,
4142
_InternalReadTableMetadataReturnValue,
4243
_TableMetadataReader,
4344
)
@@ -264,7 +265,7 @@ def _read_parquet_chunked(
264265
yield df
265266
else:
266267
if next_slice is not None:
267-
df = pd.concat(objs=[next_slice, df], sort=False, copy=False)
268+
df = _concat_union_categoricals(dfs=[next_slice, df], ignore_index=False)
268269
while len(df.index) >= chunked:
269270
yield df.iloc[:chunked, :].copy()
270271
df = df.iloc[chunked:, :]

awswrangler/s3/_read_text.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
_get_num_output_blocks,
2323
_get_path_ignore_suffix,
2424
_get_path_root,
25-
_union,
25+
_concat_union_categoricals,
2626
)
2727
from awswrangler.s3._read_text_core import _read_text_file, _read_text_files_chunked
2828
from awswrangler.typing import RaySettings
@@ -70,7 +70,7 @@ def _read_text(
7070
itertools.repeat(s3_additional_kwargs),
7171
itertools.repeat(dataset),
7272
)
73-
return _union(dfs=tables, ignore_index=ignore_index)
73+
return _concat_union_categoricals(dfs=tables, ignore_index=ignore_index)
7474

7575

7676
def _read_text_format(

0 commit comments

Comments
 (0)