@@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
1139
1139
return %1 : tensor <12 x4 xf32 >
1140
1140
}
1141
1141
// CHECK-LABEL: @fold_collapse_of_expand
1142
- // CHECK-NOT: linalg .{{.*}}shape
1142
+ // CHECK-NOT: tensor .{{.*}}_shape
1143
1143
1144
1144
// -----
1145
1145
@@ -1152,7 +1152,75 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
1152
1152
return %1 : tensor <?x?xf32 >
1153
1153
}
1154
1154
// CHECK-LABEL: @fold_collapse_of_expand_dynamic
1155
- // CHECK-NOT: linalg.{{.*}}_shape
1155
+ // CHECK-NOT: tensor.{{.*}}_shape
1156
+
1157
+ // -----
1158
+
1159
+ func.func @fold_collapse_of_expand_fully_dynamic (%arg0 : tensor <?x?xf32 >, %arg1: index , %arg2: index , %arg3: index )
1160
+ -> tensor <?x?xf32 > {
1161
+ %0 = tensor.expand_shape %arg0 [[0 , 1 ], [2 ]] output_shape [%arg1 , %arg2 , %arg3 ]
1162
+ : tensor <?x?xf32 > into tensor <?x?x?xf32 >
1163
+ %1 = tensor.collapse_shape %0 [[0 , 1 ], [2 ]]
1164
+ : tensor <?x?x?xf32 > into tensor <?x?xf32 >
1165
+ return %1 : tensor <?x?xf32 >
1166
+ }
1167
+ // CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
1168
+ // CHECK-NOT: tensor.{{.*}}_shape
1169
+
1170
+ // -----
1171
+
1172
+ func.func @no_fold_parallel_collapse_of_expand_dynamic (%arg0 : tensor <?x?x?xf32 >, %arg1: index , %arg2: index , %arg3: index , %arg4: index )
1173
+ -> tensor <?x?x?xf32 > {
1174
+ %0 = tensor.expand_shape %arg0 [[0 , 1 ], [2 ], [3 ]] output_shape [%arg1 , %arg2 , %arg3 , %arg4 ]
1175
+ : tensor <?x?x?xf32 > into tensor <?x?x?x?xf32 >
1176
+ %1 = tensor.collapse_shape %0 [[0 ], [1 ], [2 , 3 ]]
1177
+ : tensor <?x?x?x?xf32 > into tensor <?x?x?xf32 >
1178
+ return %1 : tensor <?x?x?xf32 >
1179
+ }
1180
+ // CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
1181
+ // CHECK: tensor.expand_shape
1182
+ // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
1183
+ // CHECK: return %[[COLLAPSE]]
1184
+
1185
+ // -----
1186
+
1187
+ func.func @fold_expand_of_collapse (%arg0 : tensor <3 x4 x4 xf32 >) -> tensor <3 x4 x4 xf32 > {
1188
+ %0 = tensor.collapse_shape %arg0 [[0 , 1 ], [2 ]]
1189
+ : tensor <3 x4 x4 xf32 > into tensor <12 x4 xf32 >
1190
+ %1 = tensor.expand_shape %0 [[0 , 1 ], [2 ]] output_shape [3 , 4 , 4 ]
1191
+ : tensor <12 x4 xf32 > into tensor <3 x4 x4 xf32 >
1192
+ return %1 : tensor <3 x4 x4 xf32 >
1193
+ }
1194
+ // CHECK-LABEL: @fold_expand_of_collapse
1195
+ // CHECK-NOT: tensor.{{.*}}_shape
1196
+
1197
+ // -----
1198
+
1199
+ func.func @fold_expand_of_collapse_dynamic (%arg0 : tensor <?x4 x?xf32 >, %arg1: index , %arg2: index )
1200
+ -> tensor <?x4 x?xf32 > {
1201
+ %0 = tensor.collapse_shape %arg0 [[0 , 1 ], [2 ]]
1202
+ : tensor <?x4 x?xf32 > into tensor <?x?xf32 >
1203
+ %1 = tensor.expand_shape %0 [[0 , 1 ], [2 ]] output_shape [%arg1 , 4 , %arg2 ]
1204
+ : tensor <?x?xf32 > into tensor <?x4 x?xf32 >
1205
+ return %1 : tensor <?x4 x?xf32 >
1206
+ }
1207
+ // CHECK-LABEL: @fold_expand_of_collapse_dynamic
1208
+ // CHECK-NOT: tensor.{{.*}}_shape
1209
+
1210
+ // -----
1211
+
1212
+ func.func @no_fold_expand_of_collapse_dynamic (%arg0 : tensor <?x?x?xf32 >, %arg1: index , %arg2: index , %arg3: index )
1213
+ -> tensor <?x?x?xf32 > {
1214
+ %0 = tensor.collapse_shape %arg0 [[0 , 1 ], [2 ]]
1215
+ : tensor <?x?x?xf32 > into tensor <?x?xf32 >
1216
+ %1 = tensor.expand_shape %0 [[0 , 1 ], [2 ]] output_shape [%arg1 , %arg2 , %arg3 ]
1217
+ : tensor <?x?xf32 > into tensor <?x?x?xf32 >
1218
+ return %1 : tensor <?x?x?xf32 >
1219
+ }
1220
+ // CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
1221
+ // CHECK: tensor.collapse_shape
1222
+ // CHECK: %[[EXPAND:.+]] = tensor.expand_shape
1223
+ // CHECK: return %[[EXPAND]]
1156
1224
1157
1225
// -----
1158
1226
0 commit comments