[Mlir-commits] [mlir] [mlir] Enhance dimOp fold (PR #187286)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 18 08:23:22 PDT 2026


https://github.com/yangji1993 updated https://github.com/llvm/llvm-project/pull/187286

>From 771322226b8eec0890d69f0e7d506e2060f93bae Mon Sep 17 00:00:00 2001
From: yangji <yangji1993 at gmail.com>
Date: Wed, 18 Mar 2026 22:27:27 +0800
Subject: [PATCH] [mlir] Enhance dimOp fold

Support folding tensor.dim of expand_shape/collapse_shape results.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 63 +++++++++++++++++++++++-
 1 file changed, 62 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ce0f8540d884a..a2fd5b02ff7b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1041,11 +1041,72 @@ struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Folds `tensor.dim` (constant index) of a `tensor.expand_shape` result to
+/// the matching entry of the expand op's mixed output shape.
+struct DimOfExpandShapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto expand = dim.getSource().getDefiningOp<ExpandShapeOp>();
+    if (!expand)
+      return failure();
+
+    std::optional<int64_t> index = dim.getConstantIndex();
+    // An out-of-bounds index must not crash the rewrite; fail the match.
+    if (!index || *index < 0 || *index >= expand.getResultType().getRank())
+      return failure();
+    OpFoldResult val = expand.getMixedOutputShape()[*index];
+    if (auto attr = dyn_cast<Attribute>(val)) {
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
+          dim, cast<IntegerAttr>(attr).getInt());
+    } else {
+      rewriter.replaceOp(dim, cast<Value>(val));
+    }
+    return success();
+  }
+};
+
+/// Rewrites `tensor.dim` (constant index) of a `tensor.collapse_shape` result
+/// as the product of the source dims in the matching reassociation group.
+struct DimOfCollapseShapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto collapse = dim.getSource().getDefiningOp<CollapseShapeOp>();
+    if (!collapse)
+      return failure();
+
+    std::optional<int64_t> index = dim.getConstantIndex();
+    // An out-of-bounds index must not crash the rewrite; fail the match.
+    if (!index || *index < 0 || *index >= collapse.getResultType().getRank())
+      return failure();
+    ReassociationIndices group = collapse.getReassociationIndices()[*index];
+
+    Location loc = dim.getLoc();
+    Value src = collapse.getSrc();
+    Value newDim; // Running product; stays null only if `group` is empty.
+    for (int64_t srcDim : group) {
+      Value size = rewriter.createOrFold<DimOp>(
+          loc, src, arith::ConstantIndexOp::create(rewriter, loc, srcDim));
+      if (newDim)
+        size = arith::MulIOp::create(rewriter, loc, newDim, size);
+      newDim = size;
+    }
+    if (!newDim)
+      return failure();
+    rewriter.replaceOp(dim, newDim);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
+  results.add<DimOfExpandShapeOp, DimOfCollapseShapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list