[TORCH] Add support for aten.hinge_embedding_loss Op #4227


Open · sharavak wants to merge 3 commits into main from hinge_embedding_loss
Conversation

sharavak

  • Decomposed the hinge_embedding_loss op into ATen ops (reference semantics sketched below).
  • Added e2e test cases.

This implementation addresses and closes #4222
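
For reference, the decomposed op implements the standard hinge embedding loss as defined in the PyTorch docs. Elementwise (with x = input, y = target, and margin defaulting to 1.0):

    l_n = x_n                     if y_n == 1
    l_n = max(0, margin - x_n)    if y_n == -1

with the requested 'none'/'mean'/'sum' reduction applied to the elementwise result.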

@sharavak force-pushed the hinge_embedding_loss branch 2 times, most recently from 9d5b896 to ddf112a, on June 10, 2025 14:40
@sharavak marked this pull request as ready for review June 10, 2025 14:41
@sharavak
Copy link
Author

sharavak commented Jun 10, 2025

@stellaraccident , @vivekkhandelwal1, @zjgarvey I’d be grateful if any of you could take a look at this PR. Your feedback would be greatly appreciated!

@vivekkhandelwal1 self-requested a review June 16, 2025 12:13
Comment on lines +10519 to +10622
// Compute mask: target != 1
auto targetNotOne =
    rewriter.create<AtenNeScalarOp>(loc, boolType, target, one);
// If target != 1, use marginClamp; otherwise 0.
auto outputMargin = rewriter.create<AtenWhereScalarOtherOp>(
    loc, inputTy, targetNotOne, marginClamp, zero);
// Compute mask: target != -1
auto targetNotMinusOne =
    rewriter.create<AtenNeScalarOp>(loc, boolType, target, minusOne);
// If target != -1, use the original input; otherwise 0.
auto outputSelf = rewriter.create<AtenWhereScalarOtherOp>(
    loc, inputTy, targetNotMinusOne, input, zero);
// Add: outputMargin + outputSelf
auto output = rewriter.create<AtenAddTensorOp>(loc, inputTy, outputMargin,
                                               outputSelf, /*alpha=*/alpha);
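
For reference, the quoted decomposition corresponds to the following eager-mode PyTorch sketch (illustrative only; margin_clamp is assumed to be computed as max(0, margin - input) earlier in the pattern):

import torch

def decomposition_sketch(input, target, margin=1.0):
    margin_clamp = (margin - input).clamp(min=0)  # assumed earlier step in the pattern
    zeros = torch.zeros_like(input)
    output_margin = torch.where(target != 1, margin_clamp, zeros)  # AtenWhereScalarOtherOp
    output_self = torch.where(target != -1, input, zeros)          # AtenWhereScalarOtherOp
    return output_margin + output_self                             # AtenAddTensorOp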
vivekkhandelwal1 (Collaborator)
Instead of doing all this, you can just do:

auto result = rewriter.create<AtenWhereScalarOtherOp>(
        loc, inputTy, targetNotOne, marginClamp, input);

sharavak (Author) commented Jun 17, 2025

@vivekkhandelwal1 Thanks for the suggestion. I did try the simplified version initially, but it caused numerical validation errors in some test cases, because the target tensor can hold values other than just -1 and 1.

To handle this properly and stay consistent with PyTorch's semantics, I explicitly handle both the target == 1 and target == -1 cases. This keeps the behavior correct even when target contains values outside {-1, 1}.

E.g.:

import torch
input = torch.randn(2, 3)
target = torch.randn(2, 3)  # values are generally not in {-1, 1}
torch.nn.functional.hinge_embedding_loss(input, target, reduction='none')

Output:
tensor([[1.1361, 1.0000, 1.0000],
        [1.4880, 1.1624, 1.0000]])
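
To make the difference concrete, a minimal sketch with illustrative values, assuming a target entry outside {-1, 1}:

import torch

x = torch.tensor([0.5])
t = torch.tensor([0.3])  # target value outside {-1, 1}
margin_clamp = (1.0 - x).clamp(min=0)
zeros = torch.zeros_like(x)

# Two-mask decomposition: both branches contribute when t is neither 1 nor -1.
two_mask = torch.where(t != 1, margin_clamp, zeros) + torch.where(t != -1, x, zeros)
# two_mask -> tensor([1.]) = max(0, 1 - 0.5) + 0.5

# Single where (the suggested simplification): the input term is dropped.
single = torch.where(t != 1, margin_clamp, x)
# single -> tensor([0.5000])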

@sharavak force-pushed the hinge_embedding_loss branch from ddf112a to 8d8c30b on June 17, 2025 17:13

sharavak commented Jun 17, 2025

@vivekkhandelwal1 Thanks a lot for the feedback. I've updated the code.
