[Mlir-commits] [mlir] [mlir] Enhance dimOp fold (PR #187286)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 18 07:35:12 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (yangji1993)

<details>
<summary>Changes</summary>

Support folding `tensor.dim` of `tensor.expand_shape` / `tensor.collapse_shape` results.

---
Full diff: https://github.com/llvm/llvm-project/pull/187286.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+62-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ce0f8540d884a..a2fd5b02ff7b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1041,11 +1041,72 @@ struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Folds `tensor.dim` of a `tensor.expand_shape` result into the matching
+/// entry of the expand op's mixed output shape: an `arith.constant` for a
+/// static size, or the corresponding `output_shape` SSA operand otherwise.
+///
+/// Example:
+///   %e = tensor.expand_shape %t [[0, 1]] output_shape [%sz, 4]
+///       : tensor<?xf32> into tensor<?x4xf32>
+///   %d = tensor.dim %e, %c0      -->      %sz
+///
+/// NOTE(review): confirm this does not duplicate an existing dim-of-expand
+/// folder (e.g. one registered from ExpandShapeOp's canonicalization patterns).
+struct DimOfExpandShapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto expand = dim.getSource().getDefiningOp<ExpandShapeOp>();
+    if (!expand)
+      return failure();
+
+    std::optional<int64_t> index = dim.getConstantIndex();
+    if (!index)
+      return failure();
+
+    // A constant but out-of-range dim index is not rejected here by anything
+    // visible, so guard it explicitly rather than reading past the end of the
+    // shape vector.
+    SmallVector<OpFoldResult> outputShape = expand.getMixedOutputShape();
+    if (*index < 0 || *index >= static_cast<int64_t>(outputShape.size()))
+      return failure();
+
+    OpFoldResult size = outputShape[*index];
+    if (auto attr = dyn_cast<Attribute>(size)) {
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
+          dim, cast<IntegerAttr>(attr).getInt());
+    } else {
+      rewriter.replaceOp(dim, cast<Value>(size));
+    }
+    return success();
+  }
+};
+
+/// Folds `tensor.dim` of a `tensor.collapse_shape` result into the product of
+/// the source dimensions that make up the corresponding reassociation group.
+/// Each source size is built with `createOrFold`, so static sizes become
+/// constants and static products fold away.
+///
+/// Example:
+///   %c = tensor.collapse_shape %t [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
+///   %d = tensor.dim %c, %c0      -->      (tensor.dim %t, 0) * 4
+///
+/// NOTE(review): confirm this does not duplicate an existing dim-of-collapse
+/// folder (e.g. one registered from ExpandShapeOp's canonicalization patterns).
+struct DimOfCollapseShapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto collapse = dim.getSource().getDefiningOp<CollapseShapeOp>();
+    if (!collapse)
+      return failure();
+
+    std::optional<int64_t> index = dim.getConstantIndex();
+    if (!index)
+      return failure();
+
+    // Guard against a constant but out-of-range dim index, and against an
+    // empty group, before indexing into the reassociation.
+    auto reassociation = collapse.getReassociationIndices();
+    if (*index < 0 || *index >= static_cast<int64_t>(reassociation.size()))
+      return failure();
+    const auto &group = reassociation[*index];
+    if (group.empty())
+      return failure();
+
+    Location loc = dim.getLoc();
+    Value src = collapse.getSrc();
+
+    // Multiply the source sizes of the group together. Use the integer-index
+    // DimOp builder so no dead index constants are left behind, and fold the
+    // multiplications so fully static groups collapse to a constant.
+    Value product;
+    for (int64_t srcDim : group) {
+      Value size = rewriter.createOrFold<DimOp>(loc, src, srcDim);
+      product = product
+                    ? rewriter.createOrFold<arith::MulIOp>(loc, product, size)
+                    : size;
+    }
+    rewriter.replaceOp(dim, product);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
+  // Fold dim of tensor.expand_shape / tensor.collapse_shape results.
+  results.add<DimOfExpandShapeOp, DimOfCollapseShapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//

``````````

</details>


https://github.com/llvm/llvm-project/pull/187286


More information about the Mlir-commits mailing list