[Mlir-commits] [mlir] [mlir] Fold memref.cast static-to-dynamic to memref.expand_shape (PR #170037)

Quinn Dawkins llvmlistbot at llvm.org
Tue Dec 2 11:29:02 PST 2025


================
@@ -551,6 +551,121 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
 
 // -----
 
+// CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
+// CHECK-NOT:     memref.cast
+// CHECK:         memref.expand_shape {{.*}} output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK-NOT:     memref.cast
+// CHECK:         return
+func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+  %c0 = arith.constant 0 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+  %c4 = arith.constant 4 : index
+  %dim_ext = arith.divui %dim0, %c4 : index
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+      : memref<?x4xf32> into memref<?x1x4x4xf32>
+  %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+  return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_partial(
+// CHECK-NOT:     memref.cast
+// CHECK:         memref.expand_shape {{.*}} {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %{{.*}}] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK-NOT:     memref.cast
+// CHECK:         return
----------------
qedawkins wrote:

nit: Here and below, instead of using CHECK-NOT lines to make sure the cast propagated/folded the way you expected, you can capture the result in a FileCheck variable and match it where it is used, like
```
// CHECK:  %[[EXPAND:.+]] = memref.expand_shape ...
// CHECK:  return %[[EXPAND]]
```
which is a strong check and removes the need for CHECK-NOT lines.
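Applied to the first test in the hunk above, the suggestion might look like the sketch below. It assumes the canonicalized output matches the existing CHECK line (the `memref.expand_shape` reads directly from `%arg0` and already produces `memref<2x1x4x4xf32>`); the capture names `ARG0` and `EXPAND` are illustrative only. Matching `return %[[EXPAND]]` checks that the function returns the expand_shape result itself, so any leftover `memref.cast` in between would make the test fail without a CHECK-NOT.

```
// CHECK-LABEL: func.func @fold_memref_expand_with_static_to_dynamic_cast(
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: memref<8x4xf32>
// CHECK:         %[[EXPAND:.+]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
// CHECK:         return %[[EXPAND]]
```

The `{{\[\[}}` regex is needed because FileCheck would otherwise parse a literal `[[` as the start of a variable use.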

https://github.com/llvm/llvm-project/pull/170037

