-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[TOSA] bug fix infer shape for slice #113497
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) ChangesThis fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1 added tests to check:
Full diff: https://github.com/llvm/llvm-project/pull/113497.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 631d3c48f2df02..01312584652049 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -844,8 +844,40 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
SliceOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- inferredReturnShapes.push_back(
- ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
+ auto start = adaptor.getStart();
+ auto size = adaptor.getSize();
+
+ // if size[i] is -1, all remaining elements in dimension i are included
+ // in the slice, similar to TF.
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
+ // initialize outputShape to all unknown
+ SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
+ if (inputShape.hasRank()) {
+ for (size_t i = 0; i < size.size(); i++) {
+ if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
+ (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
+ start[i] < inputShape.getDimSize(i))) {
+ // size[i] is not 0 and not < -1, and start[i] is in valid range
+ if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
+ // input shape has unknown dim[i] - only valid if size[i] > 0
+ if (size[i] > 0) {
+ outputShape[i] = size[i];
+ }
+ } else {
+ // input shape has known dim[i]
+ if (size[i] == -1) {
+ outputShape[i] = inputShape.getDimSize(i) - start[i];
+ } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
+ // start[i] + size[i] is within bound of input shape's dim[i]
+ outputShape[i] = size[i];
+ }
+ }
+ }
+ }
+ } else {
+ outputShape = convertToMlirShape(size);
+ }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e93..d2314698afa925 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -532,6 +532,48 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
// -----
+// CHECK-LABEL: @test_slice_size_minus_one
+func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: -1, -1, -1, -1>, start = array<i64: 0, 1, -1, 8>} : (tensor<?x8x8x8xi32>) -> tensor<?x7x?x?xi32>
+ // this checks following
+ // dim 0: size=-1, input dim=? => inferred output dim is ?
+ // dim 1: size=-1 => inferred output dim is input_dim - start
+ // dim 2: size=-1, start=-1 => inferred output dim is ?
+ // dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound
+ %2= tosa.slice %arg0 { start = array<i64: 0, 1, -1, 8>, size = array<i64: -1, -1, -1, -1> } : (tensor<?x8x8x8xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_size_out_of_bound
+func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: 0, -2, 9, 4>, start = array<i64: 0, 0, 0, 0>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+ // this checks following
+ // dim 0: size=0 => inferred output dim is ?
+ // dim 1: size=-2 => inferred output dim is ?
+ // dim 3: start+size out of bound because size too big: inferred output dim is ?
+ // dim 4: size=4, input dim=? => inferred output dim is 4
+ %2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_start_out_of_bound
+func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: 1, 1, 3, 4>, start = array<i64: -1, 8, 6, 8000000>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+ // this checks following
+ // dim 0: start=-1 => inferred output dim is ?
+ // dim 1: start=8 => inferred output dim is ?
+ // dim 2: start+size out of bound: inferred output dim is ?
+ // dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4
+ %2= tosa.slice %arg0 { start = array<i64: -1, 8, 6, 8000000>, size = array<i64: 1, 1, 3, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_slice_dynamic
func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
// CHECK: tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>
|
@Tai78641 can we rebase? |
db2a192
to
3c16454
Compare
done |
This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1 added tests to check: - size = -1 - size is out of bound - start is out of bound Signed-off-by: Tai Ly <[email protected]> Change-Id: I8b59502a93cb332fe5c9a9f87970b83742538126
3c16454
to
cd57590
Compare
// dim 1: size=-2 => inferred output dim is ? | ||
// dim 3: start+size out of bound because size too big: inferred output dim is ? | ||
// dim 4: size=4, input dim=? => inferred output dim is 4 | ||
%2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Tai78641 Hello, I have a question regarding this test.
Isn't the tosa.slice operation in this test input already invalid as it goes out of bound?
https://www.mlplatform.org/tosa/tosa_spec.html#_slice
This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1
added tests to check: