Skip to content

[TorchToLinalg] Casting float to integer should round to nearest for AtenPowTensorTensorOp. #4228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

cats-marin
Copy link

@cats-marin cats-marin commented Jun 11, 2025

Fixes #4091. I assume this will also need to be fixed for AtenPowScalarOp and AtenPowTensorScalarOp as well. I'm putting up a PR to ensure the initial approach is correct (new contributor :D ) before I put up another fix for AtenPowScalarOp and AtenPowTensorScalarOp.

@cats-marin cats-marin marked this pull request as ready for review June 11, 2025 05:12
@cats-marin
Copy link
Author

It'd be great if you could review this @zjgarvey!

Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, this generally looks like good work, but I have a few notes:

  1. The issue TorchToLinalg: casting float to integer should round to nearest #4091 really pertains to the behavior of the TorchOnnxToTorch lowering.
  2. Note the comment here , which indicates that the torch op only has an integer result type when both the base and exponent dtypes are integer types. This means that if we properly generate IR for AtenPowTensorTensorOp, we will never be in the situation covered by your current changes.
  3. A proper resolution to the issue is likely to edit
    rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
    binder.op, resultType, pow, outTyConst, cstFalse, cstFalse, none);
    to also include something like:
pow = rewriter.create<AtenRoundOp>(loc, pow.getType(), pow);

(Note: AtenRoundOp lowers to an elementwise application of math::RoundEvenOp).

Comment on lines +1030 to +1031
if (payloadArgs[0].getType().isInteger() &&
payloadArgs[1].getType().isInteger()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This conversion still doesn't support non-float/non-int dtypes. E.g. complex dtypes.

I'd recommend still computing the dtype and adding a check on dtype.isIntOrFloat().

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

TorchToLinalg: casting float to integer should round to nearest
2 participants