[Mlir-commits] [mlir] 7a9258e - [mlir][shape] Add a func to populate ShapeToShape patterns.

Alexander Belyaev llvmlistbot at llvm.org
Tue Jun 16 08:53:14 PDT 2020


Author: Alexander Belyaev
Date: 2020-06-16T17:52:34+02:00
New Revision: 7a9258e9bbf0f2a057fc894299d9a5a79d8c321d

URL: https://github.com/llvm/llvm-project/commit/7a9258e9bbf0f2a057fc894299d9a5a79d8c321d
DIFF: https://github.com/llvm/llvm-project/commit/7a9258e9bbf0f2a057fc894299d9a5a79d8c321d.diff

LOG: [mlir][shape] Add a func to populate ShapeToShape patterns.

Differential Revision: https://reviews.llvm.org/D81933

Added:
    (none)

Modified: 
    mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
    mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp

Removed:
    (none)


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index 29cf9d1b6715..7e6065341608 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -18,6 +18,8 @@
 
 namespace mlir {
 
+class MLIRContext;
+class OwningRewritePatternList;
 class Pass;
 
 /// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
@@ -25,6 +27,9 @@ class Pass;
 /// transformed to `shape.reduce`, which can be lowered to SCF and Standard.
 std::unique_ptr<Pass> createShapeToShapeLowering();
 
+/// Collects a set of patterns to rewrite ops within the Shape dialect.
+void populateShapeRewritePatterns(MLIRContext *context,
+                                  OwningRewritePatternList &patterns);
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_

diff  --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
index 1ba68a0a94ee..467f3d33ce23 100644
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -54,8 +54,10 @@ struct ShapeToShapeLowering
 } // namespace
 
 void ShapeToShapeLowering::runOnFunction() {
+  MLIRContext &ctx = getContext();
+
   OwningRewritePatternList patterns;
-  patterns.insert<NumElementsOpConverter>(&getContext());
+  populateShapeRewritePatterns(&ctx, patterns);
 
   ConversionTarget target(getContext());
   target.addLegalDialect<ShapeDialect>();
@@ -64,6 +66,11 @@ void ShapeToShapeLowering::runOnFunction() {
     signalPassFailure();
 }
 
+void mlir::populateShapeRewritePatterns(MLIRContext *context,
+                                        OwningRewritePatternList &patterns) {
+  patterns.insert<NumElementsOpConverter>(context);
+}
+
 std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
   return std::make_unique<ShapeToShapeLowering>();
 }


        


More information about the Mlir-commits mailing list