[Mlir-commits] [mlir] 4abccd3 - [mlir][memref][transform] Register memref dialect patterns

Matthias Springer llvmlistbot at llvm.org
Sun Jun 4 23:50:05 PDT 2023


Author: Matthias Springer
Date: 2023-06-05T08:43:39+02:00
New Revision: 4abccd3913be0fc56e0383e04b3c0a4b872db767

URL: https://github.com/llvm/llvm-project/commit/4abccd3913be0fc56e0383e04b3c0a4b872db767
DIFF: https://github.com/llvm/llvm-project/commit/4abccd3913be0fc56e0383e04b3c0a4b872db767.diff

LOG: [mlir][memref][transform] Register memref dialect patterns

Differential Revision: https://reviews.llvm.org/D151998

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 1fe13423fa5ce..91ef1620fce64 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-/// This header declares functions that assit transformations in the MemRef
+/// This header declares functions that assist transformations in the MemRef
 /// dialect.
 //
 //===----------------------------------------------------------------------===//
@@ -44,9 +44,9 @@ void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
 
 /// Appends patterns that resolve `memref.dim` operations with values that are
 /// defined by operations that implement the
-/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
+/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input
 /// operands.
-void populateResolveRankedShapeTypeResultDimsPatterns(
+void populateResolveRankedShapedTypeResultDimsPatterns(
     RewritePatternSet &patterns);
 
 /// Appends patterns that resolve `memref.dim` operations with values that are
@@ -68,7 +68,7 @@ void populateMemRefWideIntEmulationPatterns(
     arith::WideIntEmulationConverter &typeConverter,
     RewritePatternSet &patterns);
 
-/// Appends type converions for emulating wide integer memref operations with
+/// Appends type conversions for emulating wide integer memref operations with
 /// ops over narrowe integer types.
 void populateMemRefWideIntEmulationConversions(
     arith::WideIntEmulationConverter &typeConverter);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index d8eccb9675894..f23830699aeb9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -675,7 +675,7 @@ void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
   tensor::populateFoldTensorEmptyPatterns(patterns);
-  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
 }
 
@@ -689,7 +689,7 @@ void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
   tensor::populateFoldTensorEmptyPatterns(patterns);
-  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
 }
 

diff  --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index f8b449122feb5..7b636137e20ac 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -161,6 +162,23 @@ class MemRefTransformDialectExtension
 #define GET_OP_LIST
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
         >();
+
+    addDialectDataInitializer<transform::PatternRegistry>(
+        [&](transform::PatternRegistry &registry) {
+          registry.registerPatterns("memref.expand_ops",
+                                    memref::populateExpandOpsPatterns);
+          registry.registerPatterns("memref.fold_memref_alias_ops",
+                                    memref::populateFoldMemRefAliasOpPatterns);
+          registry.registerPatterns(
+              "memref.resolve_ranked_shaped_type_result_dims",
+              memref::populateResolveRankedShapedTypeResultDimsPatterns);
+          registry.registerPatterns(
+              "memref.expand_strided_metadata",
+              memref::populateExpandStridedMetadataPatterns);
+          registry.registerPatterns(
+              "memref.extract_address_computations",
+              memref::populateExtractAddressComputationsPatterns);
+        });
   }
 };
 } // namespace

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 526c1c6e198ff..9e5fc73bea06a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -121,7 +121,7 @@ struct ResolveShapedTypeResultDimsPass final
 
 } // namespace
 
-void memref::populateResolveRankedShapeTypeResultDimsPatterns(
+void memref::populateResolveRankedShapedTypeResultDimsPatterns(
     RewritePatternSet &patterns) {
   patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
                DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
@@ -138,14 +138,14 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
 
 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
-  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 }
 
 void ResolveShapedTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
-  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();


        


More information about the Mlir-commits mailing list