[mlir][vector] Update tests for xfer permutation lowering (1/N) #123076

Merged (1 commit) on Jan 20, 2025

banach-space
Contributor

  1. Remove `%c0 = arith.constant 0 : index` from test functions. This
    extra Op is not needed (the index can be passed as an argument), so
    it is just noise.
  2. Replace `%cst_0` with `%pad` to communicate what the underlying SSA
    value is intended for.
  3. Unify some comments.
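
For illustration, here is a before/after sketch of item 1, adapted from the first test touched by this PR (`@xfer_write_transposing_permutation_map`). The function bodies follow the diff; the `_before` and `_after` name suffixes are hypothetical and are used here only so that both versions can sit in one file.

```mlir
// Before: the index is materialised inside the test, which adds an extra Op
// that the CHECK lines may have to account for.
func.func @xfer_write_transposing_permutation_map_before(
    %vec: vector<4x8xi16>,
    %mem: memref<2x2x8x4xi16>) {

  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] {
    in_bounds = [true, true],
    permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
  } : vector<4x8xi16>, memref<2x2x8x4xi16>

  return
}

// After: the index is a function argument, so the body contains only the Op
// under test.
func.func @xfer_write_transposing_permutation_map_after(
    %vec: vector<4x8xi16>,
    %mem: memref<2x2x8x4xi16>,
    %idx: index) {

  vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx] {
    in_bounds = [true, true],
    permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
  } : vector<4x8xi16>, memref<2x2x8x4xi16>

  return
}
```

Because the argument list grows, `CHECK-SAME` lines that previously ended with `) {` either drop that suffix or gain a matching `%[[IDX:.*]]: index` capture.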

@llvmbot
Member

llvmbot commented Jan 15, 2025

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes
  1. Remove %c0 = arith.constant 0 : index from test functions. This
    extra Op is not needed (the index can be passed as an argument), so
    this is just noise.
  2. Replaced %cst_0 with %pad to communicate what the underlying SSA
    value is intended for.
  3. Unified some comments.

Full diff: https://github.com/llvm/llvm-project/pull/123076.diff

1 Files Affected:

  • (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+77-73)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 0feaf690af2510..045d4a9cdb5ddb 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -13,17 +13,17 @@
 
 // CHECK-LABEL:   func.func @xfer_write_transposing_permutation_map
 // CHECK-SAME:      %[[VEC:.*]]: vector<4x8xi16>,
-// CHECK-SAME:      %[[MEM:.*]]: memref<2x2x8x4xi16>) {
+// CHECK-SAME:      %[[MEM:.*]]: memref<2x2x8x4xi16>
 // CHECK:           %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
 // CHECK:           vector.transfer_write
 // CHECK-NOT:       permutation_map
 // CHECK-SAME:      %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
 func.func @xfer_write_transposing_permutation_map(
     %vec: vector<4x8xi16>,
-    %mem: memref<2x2x8x4xi16>) {
+    %mem: memref<2x2x8x4xi16>,
+    %idx: index) {
 
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] {
+  vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx] {
     in_bounds = [true, true],
     permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
   } : vector<4x8xi16>, memref<2x2x8x4xi16>
@@ -31,24 +31,25 @@ func.func @xfer_write_transposing_permutation_map(
   return
 }
 
-// Even with out-of-bounds, it is safe to apply this pattern
+// Even with out-of-bounds accesses, it is safe to apply this pattern
+
 // CHECK-LABEL:   func.func @xfer_write_transposing_permutation_map_out_of_bounds
 // CHECK-SAME:      %[[VEC:.*]]: vector<4x8xi16>,
-// CHECK-SAME:      %[[MEM:.*]]: memref<2x2x?x?xi16>) {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK-SAME:      %[[MEM:.*]]: memref<2x2x?x?xi16>,
+// CHECK-SAME:      %[[IDX:.*]]: index) {
 // CHECK:           %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
 // Expect the in_bounds attribute to be preserved. Since we don't print it when
 // all flags are "false", it should not appear in the output.
 // CHECK-NOT:       in_bounds
 // CHECK:           vector.transfer_write
 // CHECK-NOT:       permutation_map
-// CHECK-SAME:      %[[TR]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
+// CHECK-SAME:      %[[TR]], %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
 func.func @xfer_write_transposing_permutation_map_out_of_bounds(
     %vec: vector<4x8xi16>,
-    %mem: memref<2x2x?x?xi16>) {
+    %mem: memref<2x2x?x?xi16>,
+    %idx: index) {
 
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] {
+  vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx] {
     in_bounds = [false, false],
     permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
   } : vector<4x8xi16>, memref<2x2x?x?xi16>
@@ -59,7 +60,7 @@ func.func @xfer_write_transposing_permutation_map_out_of_bounds(
 // CHECK-LABEL:   func.func @xfer_write_transposing_permutation_map_with_mask_scalable
 // CHECK-SAME:      %[[VEC:.*]]: vector<4x[8]xi16>,
 // CHECK-SAME:      %[[MEM:.*]]: memref<2x2x?x4xi16>,
-// CHECK-SAME:      %[[MASK:.*]]: vector<[8]x4xi1>) {
+// CHECK-SAME:      %[[MASK:.*]]: vector<[8]x4xi1>
 // CHECK:           %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16>
 // CHECK:           vector.transfer_write
 // CHECK-NOT:       permutation_map
@@ -67,10 +68,11 @@ func.func @xfer_write_transposing_permutation_map_out_of_bounds(
 func.func @xfer_write_transposing_permutation_map_with_mask_scalable(
     %vec: vector<4x[8]xi16>,
     %mem: memref<2x2x?x4xi16>,
-    %mask: vector<[8]x4xi1>) {
+    %mask: vector<[8]x4xi1>,
+    %idx: index) {
 
   %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0], %mask {
+  vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx], %mask {
     in_bounds = [true, true],
     permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
   } : vector<4x[8]xi16>, memref<2x2x?x4xi16>
@@ -79,16 +81,18 @@ func.func @xfer_write_transposing_permutation_map_with_mask_scalable(
 }
 
 // Masked version is not supported
+
 // CHECK-LABEL:   func.func @xfer_write_transposing_permutation_map_masked
 // CHECK-NOT: vector.transpose
 func.func @xfer_write_transposing_permutation_map_masked(
     %vec: vector<4x8xi16>,
     %mem: memref<2x2x8x4xi16>,
-    %mask: vector<8x4xi1>) {
+    %mask: vector<8x4xi1>,
+    %idx: index) {
 
   %c0 = arith.constant 0 : index
   vector.mask %mask {
-    vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] {
+    vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx] {
       in_bounds = [true, true],
       permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
     } : vector<4x8xi16>, memref<2x2x8x4xi16>
@@ -128,7 +132,8 @@ func.func @xfer_write_non_transposing_permutation_map(
   return
 }
 
-// Even with out-of-bounds, it is safe to apply this pattern
+// Even with out-of-bounds accesses, it is safe to apply this pattern
+
 // CHECK-LABEL:   func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
 // CHECK-SAME:      %[[MEM:.*]]: memref<?x?xf32>,
 // CHECK-SAME:      %[[VEC:.*]]: vector<7xf32>,
@@ -157,8 +162,7 @@ func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
 // CHECK:           func.func @permutation_with_mask_xfer_write_scalable(
 // CHECK-SAME:        %[[VEC:.*]]: vector<4x[8]xi16>,
 // CHECK-SAME:        %[[MEM:.*]]: memref<1x4x?x1xi16>,
-// CHECK-SAME:        %[[MASK:.*]]: vector<4x[8]xi1>) {
-// CHECK:             %[[C0:.*]] = arith.constant 0 : index
+// CHECK-SAME:        %[[MASK:.*]]: vector<4x[8]xi1>
 // CHECK:             %[[BC_1:.*]] = vector.broadcast %[[VEC]] : vector<4x[8]xi16> to vector<1x4x[8]xi16>
 // CHECK:             %[[BC_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x4x[8]xi1>
 // CHECK:             %[[TRANSPOSE_1:.*]] =  vector.transpose %[[BC_2]], [1, 2, 0] : vector<1x4x[8]xi1> to vector<4x[8]x1xi1>
@@ -167,10 +171,10 @@ func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
 func.func @permutation_with_mask_xfer_write_scalable(
     %vec: vector<4x[8]xi16>,
     %mem: memref<1x4x?x1xi16>,
-    %mask: vector<4x[8]xi1>){
+    %mask: vector<4x[8]xi1>,
+    %idx: index){
 
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0], %mask {
+  vector.transfer_write %vec, %mem[%idx, %idx, %idx, %idx], %mask {
     in_bounds = [true, true],
     permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
   } : vector<4x[8]xi16>, memref<1x4x?x1xi16>
@@ -178,7 +182,8 @@ func.func @permutation_with_mask_xfer_write_scalable(
   return
 }
 
-// transfer_write in MaskOp case not supported.
+// Masked version is not supported
+
 // CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
 //  CHECK-SAME:   %[[DEST:.*]]: tensor<?x?xf32>,
 //  CHECK-SAME:   %[[VEC:.*]]: vector<16xf32>,
@@ -204,18 +209,19 @@ func.func @masked_permutation_xfer_write_fixed_width(
 // CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
 //  CHECK-SAME:   %[[VEC:.*]]: vector<4x[8]xi16>,
 //  CHECK-SAME:   %[[DEST:.*]]: tensor<?x?x?x?xf32>,
-//  CHECK-SAME:   %[[MASK:.*]]: vector<4x[8]xi1>)
+//  CHECK-SAME:   %[[MASK:.*]]: vector<4x[8]xi1>
 //  CHECK-SAME:   -> tensor<?x?x?x?xf32> {
 //   CHECK-NOT:   vector.transpose
 //       CHECK:   vector.mask %[[MASK]] { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
 func.func @masked_permutation_xfer_write_scalable(
     %vec: vector<4x[8]xi16>,
     %dest: tensor<?x?x?x?xf32>,
-    %mask:  vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
+    %mask:  vector<4x[8]xi1>,
+    %idx: index) -> tensor<?x?x?x?xf32> {
 
   %c0 = arith.constant 0 : index
   %res = vector.mask %mask {
-    vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0] {
+    vector.transfer_write %vec, %dest[%idx, %idx, %idx, %idx] {
       in_bounds = [true, true],
       permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
     } : vector<4x[8]xi16>, tensor<?x?x?x?xf32>
@@ -224,22 +230,23 @@ func.func @masked_permutation_xfer_write_scalable(
   return %res : tensor<?x?x?x?xf32>
 }
 
-// transfer_write in MaskOp case not supported.
+// Masked version is not supported
+
 // CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
 //  CHECK-SAME:   %[[DEST:.*]]: tensor<?x?x?x?xf32>
 //  CHECK-SAME:   %[[VEC:.*]]: vector<14x8x16xf32>
-//  CHECK-SAME:   %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
+//  CHECK-SAME:   %[[DIM:.*]]: index, %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
 //   CHECK-NOT:   vector.broadcast
 //       CHECK:   vector.mask %0 { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
 func.func @masked_non_permutation_xfer_write_fixed_width(
     %dest : tensor<?x?x?x?xf32>,
     %vec : vector<14x8x16xf32>,
-    %dim : index) -> tensor<?x?x?x?xf32> {
+    %dim : index,
+    %idx: index) -> tensor<?x?x?x?xf32> {
 
-  %c0 = arith.constant 0 : index
   %mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1>
   %res = vector.mask %mask {
-    vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0] {
+    vector.transfer_write %vec, %dest[%idx, %idx, %idx, %idx] {
       in_bounds = [false, false, true],
       permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
     } : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
@@ -259,25 +266,23 @@ func.func @masked_non_permutation_xfer_write_fixed_width(
 
 // CHECK-LABEL:   func.func @permutation_with_mask_xfer_read_fixed_width(
 // CHECK-SAME:      %[[MEM:.*]]: memref<?x?xf32>,
-// CHECK-SAME:      %[[IDX_1:.*]]: index,
-// CHECK-SAME:      %[[IDX_2:.*]]: index) -> vector<8x4x2xf32> {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK-SAME:      %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x4x2xf32> {
 // CHECK:           %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x4xi1>
-// CHECK:           %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x4xi1>
+// CHECK:           %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x4xf32> to vector<8x2x4xf32>
 // CHECK:           %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x4xf32> to vector<8x4x2xf32>
 // CHECK:           return %[[TRANSPOSE]] : vector<8x4x2xf32>
 func.func @permutation_with_mask_xfer_read_fixed_width(
     %mem: memref<?x?xf32>,
     %dim_1: index,
-    %dim_2: index) -> (vector<8x4x2xf32>) {
+    %dim_2: index,
+    %idx: index) -> (vector<8x4x2xf32>) {
 
-  %c0 = arith.constant 0 : index
-  %cst_0 = arith.constant 0.000000e+00 : f32
+  %pad = arith.constant 0.000000e+00 : f32
 
   %mask = vector.create_mask %dim_2, %dim_1 : vector<2x4xi1>
-  %res = vector.transfer_read %mem[%c0, %c0], %cst_0, %mask {
+  %res = vector.transfer_read %mem[%idx, %idx], %pad, %mask {
     in_bounds = [true, true, true],
     permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
   } : memref<?x?xf32>, vector<8x4x2xf32>
@@ -287,25 +292,23 @@ func.func @permutation_with_mask_xfer_read_fixed_width(
 
 // CHECK-LABEL:   func.func @permutation_with_mask_xfer_read_scalable(
 // CHECK-SAME:      %[[MEM:.*]]: memref<?x?xf32>,
-// CHECK-SAME:      %[[IDX_1:.*]]: index,
-// CHECK-SAME:      %[[IDX_2:.*]]: index) -> vector<8x[4]x2xf32> {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[MASK:.*]] = vector.create_mask %[[IDX_2]], %[[IDX_1]] : vector<2x[4]xi1>
-// CHECK:           %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[C0]], %[[C0]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
+// CHECK-SAME:      %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x[4]x2xf32> {
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x[4]xi1>
+// CHECK:           %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x[4]xf32> to vector<8x2x[4]xf32>
 // CHECK:           %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
 // CHECK:           return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
 func.func @permutation_with_mask_xfer_read_scalable(
     %mem: memref<?x?xf32>,
     %dim_1: index,
-    %dim_2: index) -> (vector<8x[4]x2xf32>) {
+    %dim_2: index,
+    %idx: index) -> (vector<8x[4]x2xf32>) {
 
-  %c0 = arith.constant 0 : index
-  %cst_0 = arith.constant 0.000000e+00 : f32
+  %pad = arith.constant 0.000000e+00 : f32
 
   %mask = vector.create_mask %dim_2, %dim_1 : vector<2x[4]xi1>
-  %res = vector.transfer_read %mem[%c0, %c0], %cst_0, %mask {
+  %res = vector.transfer_read %mem[%idx, %idx], %pad, %mask {
     in_bounds = [true, true, true],
     permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
   } : memref<?x?xf32>, vector<8x[4]x2xf32>
@@ -313,7 +316,8 @@ func.func @permutation_with_mask_xfer_read_scalable(
   return %res : vector<8x[4]x2xf32>
 }
 
-// transfer_read in MaskOp case not supported.
+// Masked version is not supported
+
 // CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
 //  CHECK-SAME:   %[[DEST:.*]]: tensor<?x1xf32>,
 //  CHECK-SAME:   %[[MASK:.*]]: vector<4x1xi1>
@@ -321,12 +325,12 @@ func.func @permutation_with_mask_xfer_read_scalable(
 //       CHECK:   vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
 func.func @masked_permutation_xfer_read_fixed_width(
     %dest: tensor<?x1xf32>,
-    %mask : vector<4x1xi1>) {
+    %mask : vector<4x1xi1>,
+    %idx: index) {
 
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.000000e+00 : f32
   %3 = vector.mask %mask {
-    vector.transfer_read %dest[%c0, %c0], %cst {
+    vector.transfer_read %dest[%idx, %idx], %pad {
       permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>
     } : tensor<?x1xf32>, vector<1x4x4xf32>
   } : vector<4x1xi1> -> vector<1x4x4xf32>
@@ -337,18 +341,18 @@ func.func @masked_permutation_xfer_read_fixed_width(
 
 // CHECK-LABEL:  func.func @masked_permutation_xfer_read_scalable(
 //  CHECK-SAME:    %[[DEST:.*]]: tensor<?x?xf32>,
-//  CHECK-SAME:    %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+//  CHECK-SAME:    %[[MASK:.*]]: vector<2x[4]xi1>
 //   CHECK-NOT:    vector.transpose
 //       CHECK:    %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
 func.func @masked_permutation_xfer_read_scalable(
   %dest: tensor<?x?xf32>,
-  %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+  %mask : vector<2x[4]xi1>,
+  %idx: index) -> vector<8x[4]x2xf32> {
 
-  %c0 = arith.constant 0 : index
-  %cst_0 = arith.constant 0.000000e+00 : f32
+  %pad = arith.constant 0.000000e+00 : f32
 
   %res = vector.mask %mask {
-    vector.transfer_read %dest[%c0, %c0], %cst_0 {
+    vector.transfer_read %dest[%idx, %idx], %pad {
       in_bounds = [true, true, true],
       permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
     } : tensor<?x?xf32>, vector<8x[4]x2xf32>
@@ -377,18 +381,16 @@ module attributes {transform.with_named_sequence} {
 
 //       CHECK:   #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
 //       CHECK:   func.func @transfer_read_reduce_rank_scalable(
-//  CHECK-SAME:     %[[MEM:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
-//       CHECK:     %[[C0:.*]] = arith.constant 0 : index
-//       CHECK:     %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
+//  CHECK-SAME:     %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
+//       CHECK:     %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
 //       CHECK:     %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
 //       CHECK:     return %[[BC]] : vector<8x[4]x2x3xf32>
 func.func @transfer_read_reduce_rank_scalable(
-    %mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
+    %mem: memref<?x?x?x?xf32>, %idx: index) -> vector<8x[4]x2x3xf32> {
 
-  %c0 = arith.constant 0 : index
-  %cst_0 = arith.constant 0.000000e+00 : f32
+  %pad = arith.constant 0.000000e+00 : f32
 
-  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0 {
+  %res = vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
     in_bounds = [true, true, true, true],
     permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
   } : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
@@ -396,22 +398,24 @@ func.func @transfer_read_reduce_rank_scalable(
   return %res : vector<8x[4]x2x3xf32>
 }
 
-// Masked case not supported.
+// Masked version is not supported
+
 // CHECK-LABEL:   func.func @masked_transfer_read_reduce_rank(
 //  CHECK-SAME:     %[[MEM:.*]]: memref<?x?x?x?xf32>,
-//  CHECK-SAME:     %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
+//  CHECK-SAME:     %[[DIM:.*]]: index,
+//  CHECK-SAME:     %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
 //   CHECK-NOT:     vector.broadcast
 //       CHECK:     %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
 func.func @masked_transfer_read_reduce_rank(
     %mem: memref<?x?x?x?xf32>,
-    %dim: index) -> vector<8x[4]x2x3xf32> {
+    %dim: index,
+    %idx: index) -> vector<8x[4]x2x3xf32> {
 
-  %c0 = arith.constant 0 : index
-  %cst_0 = arith.constant 0.000000e+00 : f32
+  %pad = arith.constant 0.000000e+00 : f32
   %mask = vector.create_mask %dim, %dim: vector<[4]x3xi1>
 
   %res = vector.mask %mask {
-    vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0 {
+    vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
       in_bounds = [true, true, true, true],
       permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
     } : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>

@llvmbot
Member

llvmbot commented Jan 15, 2025

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x[4]xi1>
+// CHECK:           %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x[4]xf32> to vector<8x2x[4]xf32>
 // CHECK:           %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
 // CHECK:           return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
 func.func @permutation_with_mask_xfer_read_scalable(
     %mem: memref<?x?xf32>,
     %dim_1: index,
-    %dim_2: index) -> (vector<8x[4]x2xf32>) {
+    %dim_2: index,
+    %idx: index) -> (vector<8x[4]x2xf32>) {
 
-  %c0 = arith.constant 0 : index
-  %cst_0 = arith.constant 0.000000e+00 : f32
+  %pad = arith.constant 0.000000e+00 : f32
 
   %mask = vector.create_mask %dim_2, %dim_1 : vector<2x[4]xi1>
-  %res = vector.transfer_read %mem[%c0, %c0], %cst_0, %mask {
+  %res = vector.transfer_read %mem[%idx, %idx], %pad, %mask {
     in_bounds = [true, true, true],
     permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
   } : memref<?x?xf32>, vector<8x[4]x2xf32>
@@ -313,7 +316,8 @@ func.func @permutation_with_mask_xfer_read_scalable(
   return %res : vector<8x[4]x2xf32>
 }
 
-// transfer_read in MaskOp case not supported.
+// Masked version is not supported
+
 // CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
 //  CHECK-SAME:   %[[DEST:.*]]: tensor<?x1xf32>,
 //  CHECK-SAME:   %[[MASK:.*]]: vector<4x1xi1>
@@ -321,12 +325,12 @@ func.func @permutation_with_mask_xfer_read_scalable(
 //       CHECK:   vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
 func.func @masked_permutation_xfer_read_fixed_width(
     %dest: tensor<?x1xf32>,
-    %mask : vector<4x1xi1>) {
+    %mask : vector<4x1xi1>,
+    %idx: index) {
 
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.000000e+00 : f32
   %3 = vector.mask %mask {
-    vector.transfer_read %dest[%c0, %c0], %cst {
+    vector.transfer_read %dest[%idx, %idx], %pad {
       permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>
     } : tensor<?x1xf32>, vector<1x4x4xf32>
   } : vector<4x1xi1> -> vector<1x4x4xf32>
@@ -337,18 +341,18 @@ func.func @masked_permutation_xfer_read_fixed_width(
 
 // CHECK-LABEL:  func.func @masked_permutation_xfer_read_scalable(
 //  CHECK-SAME:    %[[DEST:.*]]: tensor<?x?xf32>,
-//  CHECK-SAME:    %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+//  CHECK-SAME:    %[[MASK:.*]]: vector<2x[4]xi1>
 //   CHECK-NOT:    vector.transpose
 //       CHECK:    %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
 func.func @masked_permutation_xfer_read_scalable(
   %dest: tensor<?x?xf32>,
-  %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+  %mask : vector<2x[4]xi1>,
+  %idx: index) -> vector<8x[4]x2xf32> {
 
-  %c0 = arith.constant 0 : index
-  %cst_0 = arith.constant 0.000000e+00 : f32
+  %pad = arith.constant 0.000000e+00 : f32
 
   %res = vector.mask %mask {
-    vector.transfer_read %dest[%c0, %c0], %cst_0 {
+    vector.transfer_read %dest[%idx, %idx], %pad {
       in_bounds = [true, true, true],
       permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>
     } : tensor<?x?xf32>, vector<8x[4]x2xf32>
@@ -377,18 +381,16 @@ module attributes {transform.with_named_sequence} {
 
 //       CHECK:   #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
 //       CHECK:   func.func @transfer_read_reduce_rank_scalable(
-//  CHECK-SAME:     %[[MEM:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
-//       CHECK:     %[[C0:.*]] = arith.constant 0 : index
-//       CHECK:     %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
+//  CHECK-SAME:     %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
+//       CHECK:     %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
 //       CHECK:     %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
 //       CHECK:     return %[[BC]] : vector<8x[4]x2x3xf32>
 func.func @transfer_read_reduce_rank_scalable(
-    %mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
+    %mem: memref<?x?x?x?xf32>, %idx: index) -> vector<8x[4]x2x3xf32> {
 
-  %c0 = arith.constant 0 : index
-  %cst_0 = arith.constant 0.000000e+00 : f32
+  %pad = arith.constant 0.000000e+00 : f32
 
-  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0 {
+  %res = vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
     in_bounds = [true, true, true, true],
     permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
   } : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
@@ -396,22 +398,24 @@ func.func @transfer_read_reduce_rank_scalable(
   return %res : vector<8x[4]x2x3xf32>
 }
 
-// Masked case not supported.
+// Masked version is not supported
+
 // CHECK-LABEL:   func.func @masked_transfer_read_reduce_rank(
 //  CHECK-SAME:     %[[MEM:.*]]: memref<?x?x?x?xf32>,
-//  CHECK-SAME:     %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
+//  CHECK-SAME:     %[[DIM:.*]]: index,
+//  CHECK-SAME:     %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
 //   CHECK-NOT:     vector.broadcast
 //       CHECK:     %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
 func.func @masked_transfer_read_reduce_rank(
     %mem: memref<?x?x?x?xf32>,
-    %dim: index) -> vector<8x[4]x2x3xf32> {
+    %dim: index,
+    %idx: index) -> vector<8x[4]x2x3xf32> {
 
-  %c0 = arith.constant 0 : index
-  %cst_0 = arith.constant 0.000000e+00 : f32
+  %pad = arith.constant 0.000000e+00 : f32
   %mask = vector.create_mask %dim, %dim: vector<[4]x3xi1>
 
   %res = vector.mask %mask {
-    vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0 {
+    vector.transfer_read %mem[%idx, %idx, %idx, %idx], %pad {
       in_bounds = [true, true, true, true],
       permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
     } : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>

Contributor

@nujaa left a comment

LGTM, thanks. Since you are at it, please consider a couple of extra NITs. 😃

@@ -377,41 +381,41 @@ module attributes {transform.with_named_sequence} {

// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
// CHECK: func.func @transfer_read_reduce_rank_scalable(
Contributor

NIT: This line deserves a CHECK-LABEL.

Contributor

Noticed it as well for permutation_with_mask_xfer_write_scalable.

Contributor Author

Addressing this in #123237, alongside other improvements :)
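
For reference, a minimal sketch of what the CHECK-LABEL nit could amount to for `transfer_read_reduce_rank_scalable`. This is an illustration only and not necessarily the change made in #123237: it assumes the `#[[MAP]]` check stays directly above the function and that the map alias still matches in the same FileCheck block once the label is introduced.

```mlir
// Illustrative only: the same checks as in the diff above, with the function
// line promoted from CHECK to CHECK-LABEL (assumed form of the fix).
//       CHECK:   #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
// CHECK-LABEL:   func.func @transfer_read_reduce_rank_scalable(
//  CHECK-SAME:     %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
//       CHECK:     %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
```

The point of the label is that FileCheck splits its input at each CHECK-LABEL match, so a failing check is reported against the function it belongs to instead of matching text in a later function. The label line itself contains no variable definitions or uses, which FileCheck requires for CHECK-LABEL.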

@banach-space changed the title from "[mlir][vector] Update tests for xfer permutation lowering" to "[mlir][vector] Update tests for xfer permutation lowering (1/N)" on Jan 16, 2025
@banach-space merged commit 3b001db into llvm:main on Jan 20, 2025
12 checks passed
@banach-space deleted the andrzej/update_xfer_perm_tests branch on January 20, 2025 14:37