[Mlir-commits] [mlir] Fold `memref.dim` into `memref.expand_shape` (PR #88423)

Benoit Jacob llvmlistbot at llvm.org
Thu Apr 11 11:41:13 PDT 2024


https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/88423

The lack of this folding pattern causes TypeConverter errors downstream (in IREE), because `memref.dim` ops on `memref.expand_shape` results allow non-1D memrefs to survive past the point where we expect them to have been flattened.

This code is mostly copied from the corresponding code in TensorOps.cpp, which performs the analogous folding of `tensor.dim`. The difference is that that code builds an `AffineApplyOp`, which we cannot do here: it would create a dependency of MemRefDialect on AffineDialect, which would be circular since AffineDialect already depends on MemRefDialect.

For the same reason, this PR only folds into `expand_shape` and not into `collapse_shape`. Sorry about the asymmetry: the folding code for `collapse_shape` makes more involved use of AffineDialect, so it would have been more work to reimplement without it, and for my immediate purposes `expand_shape` is enough.
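
To illustrate the rewrite, here is a hand-written sketch (not taken from the patch; the `%src` value, shapes, and reassociation are made up for the example). A `memref.dim` that queries a dynamic result dimension of a `memref.expand_shape` becomes a `memref.dim` on the source, floor-divided by the product of the static sizes in the same reassociation group, using `arith.floordivsi` instead of `affine.apply`:

  // Before:
  %expanded = memref.expand_shape %src [[0, 1], [2]]
      : memref<?x32xf32> into memref<?x4x32xf32>
  %c0 = arith.constant 0 : index
  %d = memref.dim %expanded, %c0 : memref<?x4x32xf32>

  // After: result dim 0 is in group [0, 1]; the other member has static
  // size 4, so its size is dim(%src, 0) floor-divided by 4.
  %c0 = arith.constant 0 : index
  %src_sz = memref.dim %src, %c0 : memref<?x32xf32>
  %c4 = arith.constant 4 : index
  %d = arith.floordivsi %src_sz, %c4 : index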

>From b5a26dcecd1a3196cb97d06458d4018106e1d095 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Thu, 11 Apr 2024 13:41:41 -0400
Subject: [PATCH] foo

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 58 +++++++++++++++++++++-
 mlir/test/Dialect/MemRef/canonicalize.mlir | 16 ++++++
 2 files changed, 73 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..edc055c3180f07 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1125,11 +1125,67 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
   }
 };
 
+int64_t getCorrespondingSourceDim(ExpandShapeOp expandShapeOp,
+                                  int64_t resultDim) {
+  assert(resultDim >= 0 &&
+         resultDim < expandShapeOp.getResultType().getRank() &&
+         "invalid resultDim");
+  for (const auto &it :
+       llvm::enumerate(expandShapeOp.getReassociationIndices()))
+    if (llvm::is_contained(it.value(), resultDim))
+      return it.index();
+  assert(false && "could not find reassociation group");
+  return 0;
+}
+
+struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+
+    // Only constant dimension values are supported.
+    std::optional<int64_t> dim = dimOp.getConstantIndex();
+    if (!dim.has_value())
+      return failure();
+
+    // Skip static dims. These are folded to constant ops.
+    MemRefType resultType = expandShapeOp.getResultType();
+    if (!resultType.isDynamicDim(*dim))
+      return failure();
+
+    // Find reassociation group that contains this result dimension.
+    int64_t srcDim = getCorrespondingSourceDim(expandShapeOp, *dim);
+
+    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
+    // ExpandShapeOp would be ambiguous.)
+    int64_t product = 1;
+    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
+    for (int64_t d : grp) {
+      if (d != dim) {
+        assert(!resultType.isDynamicDim(d) && "expected static dim");
+        product *= resultType.getDimSize(d);
+      }
+    }
+
+    // result dim size = src dim size / (product(other dims in reassoc group))
+    Value srcDimSz =
+        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
+    rewriter.replaceOpWithNewOp<arith::FloorDivSIOp>(
+        dimOp, srcDimSz,
+        rewriter.create<arith::ConstantIndexOp>(dimOp.getLoc(), product));
+    return success();
+  }
+};
+
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfMemRefReshape>(context);
+  results.add<DimOfMemRefReshape, FoldDimOfExpandShape>(context);
 }
 
 // ---------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 506ed1f1c10b10..584f4d0e7067aa 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -313,6 +313,22 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
 
 // -----
 
+// Test case: Folding of memref.dim(memref.expand_shape)
+// CHECK-LABEL: func @dim_of_memref_expand_shape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 0
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
+    -> index {
+  %c1 = arith.constant 1 : index
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]]: memref<?x8xi32> into memref<1x?x2x4xi32>
+  %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
+  return %1 : index
+}
+
+// -----
+
 // Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
 // CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
 //  CHECK-SAME:   %[[MEM:[0-9a-z]+]]: memref<*xf32>,

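For the test above, the pattern's raw output is a floor-division by 1 (a sketch of the intermediate IR, before further canonicalization): the queried dim 1 sits in group [0, 1] together with dim 0 of static size 1, so the product of the other group members is 1. The usual `arith.floordivsi` fold then drops the division, leaving exactly the `memref.dim` on the source that the CHECK lines expect:

  // Right after FoldDimOfExpandShape fires (before arith folds):
  %c0 = arith.constant 0 : index
  %src_sz = memref.dim %arg0, %c0 : memref<?x8xi32>
  %c1 = arith.constant 1 : index
  %1 = arith.floordivsi %src_sz, %c1 : index  // folds to %src_sz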