[Mlir-commits] [mlir] [mlir] Fix bugs in expand_shape patterns after semantics changes (PR #94631)

Quinn Dawkins llvmlistbot at llvm.org
Thu Jun 6 09:06:57 PDT 2024


================
@@ -1152,7 +1152,60 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand_dynamic
-//   CHECK-NOT:   linalg.{{.*}}_shape
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<3x4x4xf32> into tensor<12x4xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
+      : tensor<12x4xf32> into tensor<3x4x4xf32>
+  return %1 : tensor<3x4x4xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x4x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+//       CHECK:   tensor.collapse_shape
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
+//       CHECK:   return %[[EXPAND]]
----------------
qedawkins wrote:

Please add a test for the cases where the reassociation indices differ (reflecting the above comment).

https://github.com/llvm/llvm-project/pull/94631

