[Mlir-commits] [mlir] Fold `memref.dim` into `memref.expand_shape` (PR #88423)
Benoit Jacob
llvmlistbot at llvm.org
Thu Apr 11 11:41:13 PDT 2024
https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/88423
The lack of this folding pattern causes TypeConverter errors downstream (in IREE): `memref.dim` ops on `memref.expand_shape` results cause non-1D memrefs to survive after we expect them to have been flattened.
This code is mostly copied from the corresponding TensorOps.cpp code, which performs the analogous folding of `tensor.dim`. The difference is that that code uses an `AffineApplyOp`, and we can't do that here: it could create a dependency of MemRefDialect on AffineDialect, which would be circular since AffineDialect depends on MemRefDialect.
For the same reason, this PR only folds `memref.dim` of `expand_shape` and not of `collapse_shape`. Sorry about the asymmetry: the folding code for `collapse_shape` makes more involved use of AffineDialect, so it would have been more work to reimplement without AffineDialect, and for my own immediate purposes `expand_shape` is enough.
>From b5a26dcecd1a3196cb97d06458d4018106e1d095 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Thu, 11 Apr 2024 13:41:41 -0400
Subject: [PATCH] foo
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 58 +++++++++++++++++++++-
mlir/test/Dialect/MemRef/canonicalize.mlir | 16 ++++++
2 files changed, 73 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..edc055c3180f07 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1125,11 +1125,67 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
};
+/// Returns the index of the reassociation group of `expandShapeOp` that
+/// contains result dimension `resultDim`. For `memref.expand_shape`, group `i`
+/// is the expansion of source dimension `i`, so the returned value is the
+/// source dimension that `resultDim` was expanded from.
+///
+/// `resultDim` must be a valid (in-bounds) result dimension; every valid
+/// result dimension belongs to exactly one group, so the loop always returns.
+int64_t getCorrespondingSourceDim(ExpandShapeOp expandShapeOp,
+                                  int64_t resultDim) {
+  assert(resultDim >= 0 &&
+         resultDim < expandShapeOp.getResultType().getRank() &&
+         "invalid resultDim");
+  for (const auto &it :
+       llvm::enumerate(expandShapeOp.getReassociationIndices())) {
+    if (llvm::is_contained(it.value(), resultDim))
+      return it.index();
+  }
+  // Use llvm_unreachable rather than `assert(false)` + a dummy return, which
+  // would silently yield group 0 in release builds.
+  llvm_unreachable("could not find reassociation group");
+}
+
+/// Folds `memref.dim` of a dynamic dimension of a `memref.expand_shape` result
+/// into `memref.dim` of the corresponding source dimension, divided by the
+/// product of the other (static) sizes in the same reassociation group:
+///
+///   %e = memref.expand_shape %src [[0, 1]] : memref<?xf32> into memref<?x4xf32>
+///   %d = memref.dim %e, %c0 : memref<?x4xf32>
+/// becomes
+///   %s = memref.dim %src, %c0 : memref<?xf32>
+///   %d = arith.floordivsi %s, %c4 : index
+///
+/// Only constant, in-bounds indices of dynamic result dimensions are handled;
+/// static result dimensions are left to the constant folder.
+struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+
+    // Only constant dimension values are supported.
+    std::optional<int64_t> dim = dimOp.getConstantIndex();
+    if (!dim.has_value())
+      return failure();
+
+    // An out-of-bounds constant index is undefined behavior at runtime but is
+    // still valid IR; bail out instead of tripping the bounds assertions in
+    // `isDynamicDim` / `getCorrespondingSourceDim`.
+    MemRefType resultType = expandShapeOp.getResultType();
+    if (*dim < 0 || *dim >= resultType.getRank())
+      return failure();
+
+    // Skip static dims. These are folded to constant ops.
+    if (!resultType.isDynamicDim(*dim))
+      return failure();
+
+    // Find the reassociation group (== source dimension) that contains this
+    // result dimension.
+    int64_t srcDim = getCorrespondingSourceDim(expandShapeOp, *dim);
+
+    // `dim` is the only dynamic dimension in its group. (Otherwise, the
+    // ExpandShapeOp would be ambiguous.) Multiply the remaining static sizes.
+    // Note: the group is copied because getReassociationIndices() returns a
+    // temporary.
+    ReassociationIndices group =
+        expandShapeOp.getReassociationIndices()[srcDim];
+    int64_t product = 1;
+    for (int64_t d : group) {
+      if (d == *dim)
+        continue;
+      assert(!resultType.isDynamicDim(d) && "expected static dim");
+      product *= resultType.getDimSize(d);
+    }
+
+    // A zero-sized sibling makes the source dimension 0 and leaves this
+    // dimension undetermined; never introduce a division by zero.
+    if (product == 0)
+      return failure();
+
+    // result dim size = src dim size / (product(other dims in reassoc group))
+    Location loc = dimOp.getLoc();
+    Value srcDimSz =
+        rewriter.create<DimOp>(loc, expandShapeOp.getSrc(), srcDim);
+    Value divisor = rewriter.create<arith::ConstantIndexOp>(loc, product);
+    rewriter.replaceOpWithNewOp<arith::FloorDivSIOp>(dimOp, srcDimSz, divisor);
+    return success();
+  }
+};
+
} // namespace
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfMemRefReshape>(context);
+ // Rewrite `memref.dim` of reshape-like ops (`memref.reshape`,
+ // `memref.expand_shape`) in terms of those ops' operands.
+ results.add<DimOfMemRefReshape, FoldDimOfExpandShape>(context);
}
// ---------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 506ed1f1c10b10..584f4d0e7067aa 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -313,6 +313,22 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
// -----
+// Test case: Folding of memref.dim(memref.expand_shape) where the other sizes
+// in the reassociation group multiply to 1.
+// CHECK-LABEL: func @dim_of_memref_expand_shape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 0
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
+    -> index {
+  %c1 = arith.constant 1 : index
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]]: memref<?x8xi32> into memref<1x?x2x4xi32>
+  %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.expand_shape) with a non-unit static
+// factor in the reassociation group: the source dim is divided by 4.
+// CHECK-LABEL: func @dim_of_memref_expand_shape_div(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<?xi32>
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xi32>
+//       CHECK:   %[[RES:.*]] = arith.floordivsi %[[DIM]], %[[C4]] : index
+//       CHECK:   return %[[RES]] : index
+func.func @dim_of_memref_expand_shape_div(%arg0: memref<?xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] : memref<?xi32> into memref<?x4xi32>
+  %1 = memref.dim %0, %c0 : memref<?x4xi32>
+  return %1 : index
+}
+
+// -----
+
// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
More information about the Mlir-commits
mailing list