[TORCH] Add support for aten.hinge_embedding_loss Op #4227
base: main
Conversation
Force-pushed 9d5b896 to ddf112a
@stellaraccident, @vivekkhandelwal1, @zjgarvey I'd be grateful if any of you could take a look at this PR. Your feedback would be greatly appreciated!
Resolved review threads (outdated):
- projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
- projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
- projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
// Compute mask: target != 1
auto targetNotOne =
    rewriter.create<AtenNeScalarOp>(loc, boolType, target, one);
// If target != 1, use marginClamp; otherwise 0.
auto outputMargin = rewriter.create<AtenWhereScalarOtherOp>(
    loc, inputTy, targetNotOne, marginClamp, zero);
// Compute mask: target != -1
auto targetNotMinusOne =
    rewriter.create<AtenNeScalarOp>(loc, boolType, target, minusOne);
// If target != -1, use the original input; otherwise 0.
auto outputSelf = rewriter.create<AtenWhereScalarOtherOp>(
    loc, inputTy, targetNotMinusOne, input, zero);
// Add: outputMargin + outputSelf
auto output = rewriter.create<AtenAddTensorOp>(loc, inputTy, outputMargin,
                                               outputSelf, /*alpha=*/alpha);
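The rewriter calls above mirror PyTorch's eager elementwise recipe for hinge embedding loss. As a sanity check, here is a minimal PyTorch sketch of that same recipe (the helper name is ours, for illustration only):

```python
import torch

def hinge_embedding_loss_decomposed(inp, tgt, margin=1.0):
    # max(0, margin - input): the hinge term
    margin_clamp = (margin - inp).clamp_min(0)
    # Contribution where target != 1, otherwise 0
    output_margin = torch.where(tgt != 1, margin_clamp, torch.zeros_like(inp))
    # Contribution where target != -1, otherwise 0
    output_self = torch.where(tgt != -1, inp, torch.zeros_like(inp))
    # Sum of the two masked contributions
    return output_margin + output_self

inp = torch.randn(2, 3)
tgt = torch.randint(0, 2, (2, 3)).float() * 2 - 1  # targets in {-1, 1}
ref = torch.nn.functional.hinge_embedding_loss(inp, tgt, reduction="none")
out = hinge_embedding_loss_decomposed(inp, tgt)
```

For targets restricted to {-1, 1}, `out` matches `ref` elementwise.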
Instead of doing all this, you can just do:
auto result = rewriter.create<AtenWhereScalarOtherOp>(
loc, inputTy, targetNotOne, marginClamp, input);
@vivekkhandelwal1 Thanks for the suggestion. I did try the simplified version initially, but it caused numerical validation errors in some test cases, because the target tensor can contain values other than just -1 and 1.
To handle this properly and stay consistent with PyTorch's semantics, I decided to explicitly check both target != 1 and target != -1. This way, the behavior stays correct even when target contains values other than -1 and 1.
E.g.:
import torch
input = torch.randn(2, 3)
target = torch.randn(2, 3)
torch.hinge_embedding_loss(input, target)
Output:
tensor([[1.1361, 1.0000, 1.0000],
        [1.4880, 1.1624, 1.0000]])
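The divergence between the two formulations is easy to reproduce when target takes values outside {-1, 1}. A small sketch (variable names are ours; values are arbitrary):

```python
import torch

inp = torch.tensor([[0.5, -0.3]])
tgt = torch.tensor([[0.7, 2.0]])  # neither 1 nor -1
margin = 1.0

# max(0, margin - input)
margin_clamp = (margin - inp).clamp_min(0)
# Two-mask decomposition used in this PR
full = (torch.where(tgt != 1, margin_clamp, torch.zeros_like(inp))
        + torch.where(tgt != -1, inp, torch.zeros_like(inp)))
# Simplified single-select version suggested in the review
simplified = torch.where(tgt != 1, margin_clamp, inp)

ref = torch.nn.functional.hinge_embedding_loss(inp, tgt, margin=margin,
                                               reduction="none")
```

Here `full` agrees with PyTorch's eager result, while `simplified` does not, because for targets outside {-1, 1} both masks are true and the two contributions must be summed.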
Force-pushed ddf112a to 8d8c30b
@vivekkhandelwal1 Thanks a lot for the feedback. I've updated the code.
This implementation addresses and closes #4222