[Mlir-commits] [mlir] [mlir] Enhance dimOp fold (PR #187286)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 18 07:35:12 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir
Author: None (yangji1993)
<details>
<summary>Changes</summary>
Support folding `tensor.dim` when its source is the result of a `tensor.expand_shape` or `tensor.collapse_shape` op.
---
Full diff: https://github.com/llvm/llvm-project/pull/187286.diff
1 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+62-1)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ce0f8540d884a..a2fd5b02ff7b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1041,11 +1041,72 @@ struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
return success();
}
};
+
+/// Folds `tensor.dim` of a `tensor.expand_shape` result to the matching entry
+/// of the op's mixed output shape: the SSA size value for a dynamic dimension,
+/// or an `arith.constant` index for a static one.
+struct DimOfExpandShapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto expand = dim.getSource().getDefiningOp<ExpandShapeOp>();
+    if (!expand)
+      return failure();
+
+    // Only constant dimension indices can be resolved statically.
+    std::optional<int64_t> index = dim.getConstantIndex();
+    if (!index)
+      return failure();
+
+    // An out-of-bounds index is undefined behavior at runtime but still valid
+    // IR, so bail out instead of reading past the end of the output shape.
+    SmallVector<OpFoldResult> outputShape = expand.getMixedOutputShape();
+    if (*index < 0 || *index >= static_cast<int64_t>(outputShape.size()))
+      return failure();
+
+    OpFoldResult size = outputShape[*index];
+    if (auto attr = dyn_cast<Attribute>(size)) {
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
+          dim, cast<IntegerAttr>(attr).getInt());
+      return success();
+    }
+    // Dynamic sizes are operands of the expand_shape, so they dominate `dim`.
+    rewriter.replaceOp(dim, cast<Value>(size));
+    return success();
+  }
+};
+
+/// Folds `tensor.dim` of a `tensor.collapse_shape` result into the product of
+/// the source dimension sizes that were collapsed into the queried dimension.
+struct DimOfCollapseShapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto collapse = dim.getSource().getDefiningOp<CollapseShapeOp>();
+    if (!collapse)
+      return failure();
+
+    // Only constant dimension indices can be resolved statically.
+    std::optional<int64_t> index = dim.getConstantIndex();
+    if (!index)
+      return failure();
+
+    // An out-of-bounds index is undefined behavior at runtime but still valid
+    // IR, so bail out instead of reading past the end of the reassociation.
+    auto reassociation = collapse.getReassociationIndices();
+    if (*index < 0 || *index >= static_cast<int64_t>(reassociation.size()))
+      return failure();
+    ArrayRef<int64_t> group = reassociation[*index];
+    if (group.empty())
+      return failure();
+
+    // result size = product of the source sizes in the reassociation group.
+    // createOrFold turns statically known source sizes into constants.
+    Location loc = dim.getLoc();
+    Value src = collapse.getSrc();
+    Value product = rewriter.createOrFold<DimOp>(loc, src, group.front());
+    for (int64_t srcDim : llvm::drop_begin(group)) {
+      Value srcSize = rewriter.createOrFold<DimOp>(loc, src, srcDim);
+      product = arith::MulIOp::create(rewriter, loc, product, srcSize);
+    }
+    rewriter.replaceOp(dim, product);
+    return success();
+  }
+};
} // namespace
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp,
+              DimOfExpandShapeOp, DimOfCollapseShapeOp>(context);
}
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/187286
More information about the Mlir-commits mailing list