[Mlir-commits] [mlir] [mlir] Turn `memref/tensor.dim` reification into canonicalization pattern (PR #70897)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 31 22:45:39 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Instead of having a dedicated pass that folds `memref.dim`/`tensor.dim` ops whose source is produced by an op implementing `ReifyRankedShapedTypeOpInterface`, turn the respective patterns into canonicalization patterns. This allows us to delete op-specific canonicalization patterns that do the same thing (e.g., for `tensor.empty` and `bufferization.alloc_tensor`). (Some of those op-specific patterns lack proper error checking; e.g., they crash when the dimension index is out of bounds.)

This change also slightly decouples the tensor and memref transforms build units: `MemRef/Transforms/ResolveShapedTypeResultDims.cpp` now has one fewer dependency on `tensor.dim`. The new canonicalization pattern lives in `mlir/Interfaces/InferTypeOpInterface.h`.

Also add a new `transform.apply_patterns.tensor.resolve_ranked_shaped_type_result_dims` transform op. (`transform.apply_patterns.memref.resolve_ranked_shaped_type_result_dims` no longer applies to `tensor.dim` ops.)

---

Patch is 22.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70897.diff


16 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td (+4-3) 
- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h (-6) 
- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td (-15) 
- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h (-7) 
- (modified) mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td (+13) 
- (modified) mlir/include/mlir/Interfaces/InferTypeOpInterface.h (+52) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+1-18) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+4-2) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+1) 
- (modified) mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp (+2-2) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp (+4-56) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+3-19) 
- (modified) mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp (+5) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+12) 
- (removed) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (-27) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+15) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index d7bd8410e360a76..7dd2a95f0e621a5 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -131,9 +131,10 @@ def ApplyFoldMemrefAliasOpsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
-def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.memref.resolve_ranked_shaped_type_result_dims",
-    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+def ApplyMemrefResolveRankedShapedTypeResultDimsPatternsOp
+    : Op<Transform_Dialect,
+        "apply_patterns.memref.resolve_ranked_shaped_type_result_dims",
+        [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
     Collects patterns that resolve `memref.dim` operations with values that are
     defined by operations that implement the `ReifyRankedShapedTypeOpInterface`,
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index f502aac79927094..5cc0b818de4c20f 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -57,12 +57,6 @@ std::unique_ptr<Pass> createFoldMemRefAliasOpsPass();
 /// (identity) layout map.
 std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
 
-/// Creates an operation pass to resolve `memref.dim` operations with values
-/// that are defined by operations that implement the
-/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
-/// operands.
-std::unique_ptr<Pass> createResolveRankedShapeTypeResultDimsPass();
-
 /// Creates an operation pass to resolve `memref.dim` operations with values
 /// that are defined by operations that implement the
 /// `InferShapedTypeOpInterface` or the `ReifyRankedShapeTypeShapeOpInterface`,
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index d7ee492b9e990e0..07bf42deabb26b8 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -159,21 +159,6 @@ def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
   let dependentDialects = ["affine::AffineDialect"];
 }
 
-def ResolveRankedShapeTypeResultDims :
-    Pass<"resolve-ranked-shaped-type-result-dims"> {
-  let summary = "Resolve memref.dim of result values of ranked shape type";
-  let description = [{
-    The pass resolves memref.dim of result of operations that
-    implement the `ReifyRankedShapedTypeOpInterface` in terms of
-    shapes of its operands.
-  }];
-  let constructor =
-      "mlir::memref::createResolveRankedShapeTypeResultDimsPass()";
-  let dependentDialects = [
-    "memref::MemRefDialect", "tensor::TensorDialect"
-  ];
-}
-
 def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
   let summary = "Resolve memref.dim of result values";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index a918f62cbc8db8f..50000691a2928de 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -45,13 +45,6 @@ void populateExpandOpsPatterns(RewritePatternSet &patterns);
 /// ops into `patterns`.
 void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
 
-/// Appends patterns that resolve `memref.dim` operations with values that are
-/// defined by operations that implement the
-/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input
-/// operands.
-void populateResolveRankedShapedTypeResultDimsPatterns(
-    RewritePatternSet &patterns);
-
 /// Appends patterns that resolve `memref.dim` operations with values that are
 /// defined by operations that implement the `InferShapedTypeOpInterface`, in
 /// terms of shapes of its input operands.
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 66c6021418b471c..af598a5b35fab32 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -99,6 +99,19 @@ def ApplyReassociativeReshapeFoldingPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyTensorResolveRankedShapedTypeResultDimsPatternsOp
+    : Op<Transform_Dialect,
+        "apply_patterns.tensor.resolve_ranked_shaped_type_result_dims",
+        [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns that resolve `tensor.dim` operations with values that are
+    defined by operations that implement the `ReifyRankedShapedTypeOpInterface`,
+    in terms of shapes of its input operands.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
     "apply_patterns.tensor.rewrite_as_constant",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 67de05b0cb4ff34..79720807a2e7b8d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/SmallVector.h"
@@ -277,6 +278,57 @@ template <typename ConcreteType>
 class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {};
 
 } // namespace OpTrait
+
+namespace {
+/// Fold dim of an operation that implements ReifyRankedShapedTypeOpInterface.
+template <typename OpTy>
+struct FoldDimOfReifyRankedShapedTypeOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); }
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const override {
+    OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
+    if (!dimValue)
+      return failure();
+    // Can fold only if the dimension is a constant.
+    std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
+    if (!dimIndex)
+      return failure();
+    // Reify result dimensions.
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
+    if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
+                                 reifiedResultShapes)))
+      return rewriter.notifyMatchFailure(dimOp,
+                                         "failed to reify result shapes");
+    unsigned resultNumber = dimValue.getResultNumber();
+    // Do not apply pattern if the IR is invalid (dim out of bounds).
+    if (*dimIndex >= reifiedResultShapes[resultNumber].size())
+      return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
+    OpFoldResult dimSize = reifiedResultShapes[resultNumber][*dimIndex];
+    // If the dim size is a value, replace the op directly.
+    if (auto value = dimSize.dyn_cast<Value>()) {
+      rewriter.replaceOp(dimOp, value);
+      return success();
+    }
+    // Otherwise, materialize a constant value.
+    rewriter.replaceOp(dimOp, dimOp->getDialect()->materializeConstant(
+                                  rewriter, dimSize.get<Attribute>(),
+                                  rewriter.getIndexType(), dimOp->getLoc()));
+    return success();
+  }
+};
+} // namespace
+
+/// Populate `patterns` with a pattern that folds dim ops of type OpTy whose
+/// source is produced by an op implementing ReifyRankedShapedTypeOpInterface.
+template <typename OpTy>
+void populateResolveRankedShapedTypeResultDimsPattern(
+    RewritePatternSet &patterns) {
+  patterns.insert<FoldDimOfReifyRankedShapedTypeOp<OpTy>>(
+      patterns.getContext());
+}
 } // namespace mlir
 
 #endif // MLIR_INTERFACES_INFERTYPEOPINTERFACE_H_
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 8f19245efdba6c8..1d6897ebf3437aa 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -329,28 +329,11 @@ struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
     return success();
   }
 };
-
-struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
-  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
-    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
-    if (!allocTensorOp || !maybeConstantIndex)
-      return failure();
-    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
-      return failure();
-    rewriter.replaceOp(
-        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
-    return success();
-  }
-};
 } // namespace
 
 void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *ctx) {
-  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
+  results.add<ReplaceStaticShapeDims>(ctx);
 }
 
 LogicalResult AllocTensorOp::reifyResultShapes(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 2e3610b7c08d9da..ea6239e78a66656 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -647,7 +647,8 @@ populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
   tensor::populateFoldTensorEmptyPatterns(patterns);
-  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+  populateResolveRankedShapedTypeResultDimsPattern<memref::DimOp>(patterns);
+  populateResolveRankedShapedTypeResultDimsPattern<tensor::DimOp>(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
 }
 
@@ -662,7 +663,8 @@ populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
   tensor::populateFoldTensorEmptyPatterns(patterns);
-  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+  populateResolveRankedShapedTypeResultDimsPattern<memref::DimOp>(patterns);
+  populateResolveRankedShapedTypeResultDimsPattern<tensor::DimOp>(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
 }
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 215a8f5e7d18be0..749802837186227 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1136,6 +1136,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<DimOfMemRefReshape>(context);
+  populateResolveRankedShapedTypeResultDimsPattern<memref::DimOp>(results);
 }
 
 // ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index eed29efcaaada88..d56fa102451c366 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -121,9 +121,9 @@ void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
   memref::populateFoldMemRefAliasOpPatterns(patterns);
 }
 
-void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
+void transform::ApplyMemrefResolveRankedShapedTypeResultDimsPatternsOp::
     populatePatterns(RewritePatternSet &patterns) {
-  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
+  populateResolveRankedShapedTypeResultDimsPattern<memref::DimOp>(patterns);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 0cb5931ce6bf9b9..9f3f33aadf93fb8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -24,7 +24,6 @@
 
 namespace mlir {
 namespace memref {
-#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
 #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
 } // namespace memref
@@ -72,37 +71,6 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     return success();
   }
 };
-
-/// Fold dim of an operation that implements the InferShapedTypeOpInterface
-template <typename OpTy>
-struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); }
-
-  LogicalResult matchAndRewrite(OpTy dimOp,
-                                PatternRewriter &rewriter) const override {
-    OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
-    if (!dimValue)
-      return failure();
-    std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
-    if (!dimIndex)
-      return failure();
-
-    ReifiedRankedShapedTypeDims reifiedResultShapes;
-    if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
-                                 reifiedResultShapes)))
-      return failure();
-    unsigned resultNumber = dimValue.getResultNumber();
-    // Do not apply pattern if the IR is invalid (dim out of bounds).
-    if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
-      return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
-    Value replacement = getValueOrCreateConstantIndexOp(
-        rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
-    rewriter.replaceOp(dimOp, replacement);
-    return success();
-  }
-};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -110,11 +78,6 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
 //===----------------------------------------------------------------------===//
 
 namespace {
-struct ResolveRankedShapeTypeResultDimsPass final
-    : public memref::impl::ResolveRankedShapeTypeResultDimsBase<
-          ResolveRankedShapeTypeResultDimsPass> {
-  void runOnOperation() override;
-};
 
 struct ResolveShapedTypeResultDimsPass final
     : public memref::impl::ResolveShapedTypeResultDimsBase<
@@ -124,13 +87,6 @@ struct ResolveShapedTypeResultDimsPass final
 
 } // namespace
 
-void memref::populateResolveRankedShapedTypeResultDimsPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
-               DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
-      patterns.getContext());
-}
-
 void memref::populateResolveShapedTypeResultDimsPatterns(
     RewritePatternSet &patterns) {
   // TODO: Move tensor::DimOp pattern to the Tensor dialect.
@@ -139,17 +95,13 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
       patterns.getContext());
 }
 
-void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
-  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
-    return signalPassFailure();
-}
-
 void ResolveShapedTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
-  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
+  // TODO: `populateResolveRankedShapedTypeResultDimsPattern` does not really
+  // belong here.
+  populateResolveRankedShapedTypeResultDimsPattern<tensor::DimOp>(patterns);
+  populateResolveRankedShapedTypeResultDimsPattern<memref::DimOp>(patterns);
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 }
@@ -157,7 +109,3 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
 std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
   return std::make_unique<ResolveShapedTypeResultDimsPass>();
 }
-
-std::unique_ptr<Pass> memref::createResolveRankedShapeTypeResultDimsPass() {
-  return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
-}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f719cfed6b6dd30..29e623a06933cce 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -606,6 +606,7 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  populateResolveRankedShapedTypeResultDimsPattern<tensor::DimOp>(results);
 }
 
 //===----------------------------------------------------------------------===//
@@ -737,23 +738,6 @@ struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
   }
 };
 
-struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
-    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
-    if (!emptyTensorOp || !maybeConstantIndex)
-      return failure();
-    if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
-      return failure();
-    rewriter.replaceOp(dimOp,
-                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
-    return success();
-  }
-};
-
 /// Canonicalize
 ///
 /// ```mlir
@@ -830,8 +814,8 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
 
 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
-              ReplaceEmptyTensorStaticShapeDims>(context);
+  results.add<FoldEmptyTensorWithCastOp, ReplaceEmptyTensorStaticShapeDims>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 3cec91389392246..d92f68712a9972b 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -118,6 +118,11 @@ void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
 }
 
+void transform::ApplyTensorResolveRankedShapedTypeResultDimsPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  populateResolveRankedShapedTypeResultDimsPattern<tensor::DimOp>(patterns);
+}
+
 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   tensor::populateRewriteAsConstantPatterns(patterns);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemR...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/70897


More information about the Mlir-commits mailing list