Commit 5b0d719

Add NumToTensor

1 parent 18e8806 commit 5b0d719

File tree

  e2e_testing/torchscript/basic.py
  lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
  lib/Dialect/Torch/Transforms/RefineTypes.cpp
  test/Conversion/TorchToLinalg/basic.mlir
  test/Dialect/Torch/refine-types.mlir

5 files changed: 90 additions, 10 deletions

e2e_testing/torchscript/basic.py

Lines changed: 16 additions & 0 deletions
@@ -542,3 +542,19 @@ def forward(self, tensor):
 @register_test_case(module_factory=lambda: LogSoftmaxIntModule())
 def LogSoftmaxIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4).double())
+
+class NumToTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+
+    def forward(self):
+        return torch.ops.prim.NumToTensor(1)
+
+@register_test_case(module_factory=lambda: NumToTensorModule())
+def NumToTensorModule_basic(module, tu: TestUtils):
+    module.forward()
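For context (not part of the commit): prim::NumToTensor wraps a scalar into a rank-0 tensor, which is exactly what the new e2e test exercises. A minimal eager-mode sketch, assuming only stock PyTorch:

    import torch

    # prim::NumToTensor turns a scalar into a 0-d tensor; for an int scalar the
    # result is int64, matching the !torch.vtensor<[],si64> type in the lowering
    # test added below.
    t = torch.ops.prim.NumToTensor(1)
    print(t.shape)  # torch.Size([])
    print(t.dtype)  # torch.int64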

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 26 additions & 0 deletions
@@ -2794,6 +2794,29 @@ class ConvertAtenOnesOp : public OpConversionPattern<AtenOnesOp> {
 };
 } // namespace
 
+namespace {
+class ConvertPrimNumToTensorScalarOp
+    : public OpConversionPattern<PrimNumToTensorScalarOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(PrimNumToTensorScalarOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    PrimNumToTensorScalarOp::Adaptor adaptor(operands);
+    Location loc = op.getLoc();
+    Value a = adaptor.a();
+    Value outTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, ValueRange{}, a.getType())
+            ->getResult(0);
+    rewriter.replaceOpWithNewOp<linalg::FillOp>(op, a, outTensor);
+
+    return success();
+  }
+};
+} // namespace
+
 // -----------------------------------------------------------------------------
 // The pass
 // -----------------------------------------------------------------------------

@@ -2878,6 +2901,9 @@ class ConvertTorchToLinalg
     patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
     target.addIllegalOp<AtenIntTensorOp>();
     patterns.add<ConvertAtenIntTensorOp>(typeConverter, context);
+    target.addIllegalOp<PrimNumToTensorScalarOp>();
+    patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
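The pattern above lowers torch.prim.NumToTensor.Scalar by materializing a rank-0 linalg.init_tensor and filling it with the converted scalar operand. A rough eager-PyTorch analogue of that init-then-fill sequence (illustration only, not code from this commit):

    import torch

    scalar = 1
    # linalg.init_tensor [] : tensor<i64>  ->  allocate an uninitialized rank-0 tensor
    out = torch.empty((), dtype=torch.int64)
    # linalg.fill(%scalar, %init)          ->  write the scalar into that tensor
    out.fill_(scalar)
    assert out.shape == torch.Size([]) and out.item() == 1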

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 11 additions & 0 deletions
@@ -416,6 +416,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
     } else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
       return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
+    } else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
+      return visitNumToTensorOp(numToTensorOp);
     }
 
     // Otherwise, this is an unknown operation. Just mark all results as

@@ -477,6 +479,7 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   ChangeResult
   visitAtenPermuteOp(AtenPermuteOp op,
                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult visitNumToTensorOp(PrimNumToTensorScalarOp op);
   template <typename OpTy>
   ChangeResult visitScalarToTensorConversionOp(OpTy op);
   ChangeResult visitAtenTensorOp(AtenTensorOp op);

@@ -1262,6 +1265,14 @@ ChangeResult TypeAnalyzer::visitAtenShapeAsTensorOp(
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  knowledge.hasSizes = true;
+  knowledge.dtype = getDefaultDtypeForTorchScalar(op.a().getType());
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
     AtenEmbeddingOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
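visitNumToTensorOp marks the result as a rank-0 tensor (known, empty sizes) and takes its dtype from getDefaultDtypeForTorchScalar on the scalar operand's type. The int case is what the tests in this commit cover; the float case in the sketch below is my assumption, based on TorchScript treating Python floats as 64-bit:

    import torch

    # Int scalar -> 0-d int64 tensor (refined to !torch.tensor<[],si64> in the test below).
    print(torch.ops.prim.NumToTensor(1).dtype)    # torch.int64
    # Float scalar -> 0-d float64 tensor (assumption: TorchScript floats are doubles).
    print(torch.ops.prim.NumToTensor(1.0).dtype)  # torch.float64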

test/Conversion/TorchToLinalg/basic.mlir

Lines changed: 14 additions & 0 deletions
@@ -67,3 +67,17 @@ func @integer_extract(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
   %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
   return %0 : !torch.int
 }
+
+// -----
+
+// CHECK: func @torch.prim.NumToTensor.Scalar$basic(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> {
+// CHECK: %[[INI64:.*]] = torch_c.to_i64 %[[IN]]
+// CHECK: %[[NEWVEC:.*]] = linalg.init_tensor [] : tensor<i64>
+// CHECK: %[[FILLVEC:.*]] = linalg.fill(%[[INI64]], %[[NEWVEC]]) : i64, tensor<i64> -> tensor<i64>
+// CHECK: %[[OUTVEC:.*]] = torch_c.from_builtin_tensor %[[FILLVEC]] : tensor<i64> -> !torch.vtensor<[],si64>
+// CHECK: return %[[OUTVEC]] : !torch.vtensor<[],si64>
+
+func @torch.prim.NumToTensor.Scalar$basic(%arg0: !torch.int) -> !torch.vtensor<[],si64> {
+  %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64>
+  return %0 : !torch.vtensor<[],si64>
+}

test/Dialect/Torch/refine-types.mlir

Lines changed: 23 additions & 10 deletions
@@ -979,28 +979,28 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
 
 
 // ----
-// CHECK-LABEL: func @aten_matmul_broadcast_matrix(
+// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
 // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
 // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>)
 // CHECK-SAME: -> !torch.tensor {
-// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func @aten_matmul_broadcast_matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
+// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
+// CHECK: return %[[CAST]] : !torch.tensor
+func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
   %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
   return %0 : !torch.tensor
 }
 
 
 // ----
-// CHECK-LABEL: func @aten_matmul_broadcast_vector(
+// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
 // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
 // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>)
 // CHECK-SAME: -> !torch.tensor {
-// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func @aten_matmul_broadcast_vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor {
+// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
+// CHECK: return %[[CAST]] : !torch.tensor
+func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor {
   %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor
   return %0 : !torch.tensor
 }

@@ -1022,3 +1022,16 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
   %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
   return %0 : !torch.tensor
 }
+
+// ----
+// CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
+// CHECK-SAME: %[[SELF:.*]]: !torch.int)
+// CHECK-SAME: -> !torch.tensor {
+// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<[],si64> to !torch.tensor
+// CHECK: return %[[CAST]] : !torch.tensor
+
+func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
+  %0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
+  return %0: !torch.tensor
+}
