[Mlir-commits] [mlir] [mlir] Enhance dimOp fold (PR #187286)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 18 08:02:11 PDT 2026
https://github.com/yangji1993 updated https://github.com/llvm/llvm-project/pull/187286
>From 771322226b8eec0890d69f0e7d506e2060f93bae Mon Sep 17 00:00:00 2001
From: yangji <yangji1993 at gmail.com>
Date: Wed, 18 Mar 2026 22:27:27 +0800
Subject: [PATCH] [mlir] Enhance dimOp fold
Add tensor.dim canonicalization patterns that resolve the size of a
tensor.expand_shape or tensor.collapse_shape result from the op's operands.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 63 +++++++++++++++++++++++-
1 file changed, 62 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ce0f8540d884a..a2fd5b02ff7b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1041,11 +1041,72 @@ struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
return success();
}
};
+
+struct DimOfExpandShapeOp : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dim,
+ PatternRewriter &rewriter) const override {
+ auto expand = dim.getSource().getDefiningOp<ExpandShapeOp>();
+
+ if (!expand)
+ return failure();
+
+ auto index = dim.getConstantIndex();
+ if (!index.has_value())
+ return failure();
+
+ OpFoldResult val = expand.getMixedOutputShape()[index.value()];
+ if (auto attr = dyn_cast<Attribute>(val)) {
+ rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
+ dim, cast<IntegerAttr>(attr).getInt());
+ } else {
+ rewriter.replaceOp(dim, cast<Value>(val));
+ }
+
+ return success();
+ }
+};
+
+struct DimOfCollapseShapeOp : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dim,
+ PatternRewriter &rewriter) const override {
+ auto collapse = dim.getSource().getDefiningOp<CollapseShapeOp>();
+
+ if (!collapse)
+ return failure();
+
+ auto index = dim.getConstantIndex();
+ if (!index.has_value())
+ return failure();
+
+ Location loc = dim.getLoc();
+ auto src = collapse.getSrc();
+ auto indices = collapse.getReassociationIndices()[index.value()];
+
+ SmallVector<Value, 4> dims;
+ for (auto indice : indices) {
+ dims.push_back(rewriter.createOrFold<DimOp>(
+ loc, src, arith::ConstantIndexOp::create(rewriter, loc, indice)));
+ }
+ Value newDim = dims[0];
+
+ for (size_t i = 1; i < dims.size(); i++) {
+ newDim = arith::MulIOp::create(rewriter, loc, newDim, dims[i]);
+ }
+ rewriter.replaceOp(dim, newDim);
+
+ return success();
+ }
+};
} // namespace
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
+ results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp, DimOfExpandShapeOp,
+ DimOfCollapseShapeOp>(context);
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list