-
Notifications
You must be signed in to change notification settings - Fork 28.6k
[SPARK-48710][PYTHON] Use NumPy 2.0 compatible types #47083
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e6dd745
87f06bd
7512c71
6103cd5
2851ae0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5370,6 +5370,18 @@ def _test() -> None: | |
import tempfile | ||
from pyspark.core.context import SparkContext | ||
|
||
try: | ||
# Numpy 2.0+ changed its string format, | ||
# adding type information to numeric scalars. | ||
import numpy as np | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here actually did a try catch but I think there's some issue related to import .. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's move this to sql and ml only for now because both modules use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you. Yes, the problem is that
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
+1 for moving. |
||
from pandas.util.version import Version | ||
|
||
if Version(np.__version__) >= Version("2"): | ||
# `legacy="1.25"` only available in `numpy>=2` | ||
np.set_printoptions(legacy="1.25") # type: ignore[arg-type] | ||
except TypeError: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, yeah let's catch ImportError .... |
||
pass | ||
|
||
tmp_dir = tempfile.TemporaryDirectory() | ||
globs = globals().copy() | ||
# The small batch size here ensures that we see multiple batches, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1833,9 +1833,16 @@ def _test() -> None: | |
import sys | ||
from pyspark.sql import SparkSession | ||
import pyspark.pandas.indexing | ||
from pandas.util.version import Version | ||
|
||
os.chdir(os.environ["SPARK_HOME"]) | ||
|
||
if Version(np.__version__) >= Version("2"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
# Numpy 2.0+ changed its string format, | ||
# adding type information to numeric scalars. | ||
# `legacy="1.25"` only available in `numpy>=2` | ||
np.set_printoptions(legacy="1.25") # type: ignore[arg-type] | ||
|
||
globs = pyspark.pandas.indexing.__dict__.copy() | ||
globs["ps"] = pyspark.pandas | ||
spark = ( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,7 +38,7 @@ def pser(self): | |
"\nleading-whitespace", | ||
"trailing-Whitespace \t", | ||
None, | ||
np.NaN, | ||
np.nan, | ||
] | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,7 +38,7 @@ def pser(self): | |
"\nleading-whitespace", | ||
"trailing-Whitespace \t", | ||
None, | ||
np.NaN, | ||
np.nan, | ||
] | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -176,7 +176,7 @@ def as_spark_type( | |
return None | ||
return types.ArrayType(element_type) | ||
# BinaryType | ||
elif tpe in (bytes, np.character, np.bytes_, np.string_): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, nvm. I just checked the PR description. |
||
elif tpe in (bytes, np.character, np.bytes_): | ||
return types.BinaryType() | ||
# BooleanType | ||
elif tpe in (bool, np.bool_, "bool", "?"): | ||
|
@@ -190,7 +190,7 @@ def as_spark_type( | |
elif tpe in (decimal.Decimal,): | ||
# TODO: considering the precision & scale for decimal type. | ||
return types.DecimalType(38, 18) | ||
elif tpe in (float, np.float_, np.float64, "float", "float64", "double"): | ||
elif tpe in (float, np.double, np.float64, "float", "float64", "double"): | ||
return types.DoubleType() | ||
elif tpe in (np.float32, "float32", "f"): | ||
return types.FloatType() | ||
|
@@ -201,7 +201,7 @@ def as_spark_type( | |
elif tpe in (np.int16, "int16", "short"): | ||
return types.ShortType() | ||
# StringType | ||
elif tpe in (str, np.unicode_, "str", "U"): | ||
elif tpe in (str, np.str_, "str", "U"): | ||
return types.StringType() | ||
# TimestampType or TimestampNTZType if timezone is not specified. | ||
elif tpe in (datetime.datetime, np.datetime64, "datetime64[ns]", "M", pd.Timestamp): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we add a todo that once we upgrade the minimum version >= 2.0, we can remove this try-except and update the doc tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There exist also multiple tests already with:
I'd guess these should be considered for updating before that (since the minimum NumPy version is at
1.21
currently).