Author: Benoit Jacob (bjacob)


The lack of this folding pattern causes TypeConverter errors downstream (IREE), because a `memref.dim` on a `memref.expand_shape` causes non-1D memrefs to survive after we expect them to have been flattened.

This code is mostly copied from the corresponding code in TensorOps.cpp, which performs the analogous folding of `tensor.dim`. The difference is that the TensorOps.cpp code uses an `AffineApplyOp`, which we can't do here because it would create a dependency of MemRefDialect on AffineDialect; that dependency would be circular, as AffineDialect depends on MemRefDialect.

For the same reason, this PR only folds into `expand_shape` and not `collapse_shape`. Sorry about the asymmetry: the folding code for `collapse_shape` makes more involved use of AffineDialect, so it would have been more work to reimplement without AffineDialect, and for my own immediate purposes `expand_shape` is enough.

2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+57-1) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+16) 

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..edc055c3180f07 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1125,11 +1125,67 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
+int64_t getCorrespondingSourceDim(ExpandShapeOp expandShapeOp,
+                                  int64_t resultDim) {
+  assert(resultDim >= 0 &&
+         resultDim < expandShapeOp.getResultType().getRank() &&
+         "invalid resultDim");
+  for (const auto &it :
+       llvm::enumerate(expandShapeOp.getReassociationIndices()))
+    if (llvm::is_contained(it.value(), resultDim))
+      return it.index();
+  assert(false && "could not find reassociation group");
+  return 0;
+struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+    // Only constant dimension values are supported.
+    std::optional<int64_t> dim = dimOp.getConstantIndex();
+    if (!dim.has_value())
+      return failure();
+    // Skip static dims. These are folded to constant ops.
+    MemRefType resultType = expandShapeOp.getResultType();
+    if (!resultType.isDynamicDim(*dim))
+      return failure();
+    // Find reassociation group that contains this result dimension.
+    int64_t srcDim = getCorrespondingSourceDim(expandShapeOp, *dim);
+    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
+    // ExpandShapeOp would be ambiguous.)
+    int64_t product = 1;
+    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
+    for (int64_t d : grp) {
+      if (d != dim) {
+        assert(!resultType.isDynamicDim(d) && "expected static dim");
+        product *= resultType.getDimSize(d);
+      }
+    }
+    // result dim size = src dim size / (product(other dims in reassoc group))
+    Value srcDimSz =
+        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
+    rewriter.replaceOpWithNewOp<arith::FloorDivSIOp>(
+        dimOp, srcDimSz,
+        rewriter.create<arith::ConstantIndexOp>(dimOp.getLoc(), product));
+    return success();
+  }
 } // namespace
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfMemRefReshape>(context);
+  results.add<DimOfMemRefReshape, FoldDimOfExpandShape>(context);
 // ---------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 506ed1f1c10b10..584f4d0e7067aa 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -313,6 +313,22 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
 // -----
+// Test case: Folding of memref.dim(memref.expand_shape)
+// CHECK-LABEL: func @dim_of_memref_expand_shape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 0
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
+    -> index {
+  %c1 = arith.constant 1 : index
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]]: memref<?x8xi32> into memref<1x?x2x4xi32>
+  %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
+  return %1 : index
+// -----
 // Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
 // CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
 //  CHECK-SAME:   %[[MEM:[0-9a-z]+]]: memref<*xf32>,




