[Mlir-commits] [mlir] 7a9258e - [mlir][shape] Add a func to populate ShapeToShape patterns.
Alexander Belyaev
llvmlistbot at llvm.org
Tue Jun 16 08:53:14 PDT 2020
Author: Alexander Belyaev
Date: 2020-06-16T17:52:34+02:00
New Revision: 7a9258e9bbf0f2a057fc894299d9a5a79d8c321d
URL: https://github.com/llvm/llvm-project/commit/7a9258e9bbf0f2a057fc894299d9a5a79d8c321d
DIFF: https://github.com/llvm/llvm-project/commit/7a9258e9bbf0f2a057fc894299d9a5a79d8c321d.diff
LOG: [mlir][shape] Add a func to populate ShapeToShape patterns.
Differential Revision: https://reviews.llvm.org/D81933
Added:
Modified:
mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index 29cf9d1b6715..7e6065341608 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -18,6 +18,8 @@
namespace mlir {
+class MLIRContext;
+class OwningRewritePatternList;
class Pass;
/// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
@@ -25,6 +27,9 @@ class Pass;
/// transformed to `shape.reduce`, which can be lowered to SCF and Standard.
std::unique_ptr<Pass> createShapeToShapeLowering();
+/// Collects a set of patterns to rewrite ops within the Shape dialect.
+void populateShapeRewritePatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns);
} // end namespace mlir
#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
index 1ba68a0a94ee..467f3d33ce23 100644
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -54,8 +54,10 @@ struct ShapeToShapeLowering
} // namespace
void ShapeToShapeLowering::runOnFunction() {
+ MLIRContext &ctx = getContext();
+
OwningRewritePatternList patterns;
- patterns.insert<NumElementsOpConverter>(&getContext());
+ populateShapeRewritePatterns(&ctx, patterns);
ConversionTarget target(getContext());
target.addLegalDialect<ShapeDialect>();
@@ -64,6 +66,11 @@ void ShapeToShapeLowering::runOnFunction() {
signalPassFailure();
}
+void mlir::populateShapeRewritePatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns) {
+ patterns.insert<NumElementsOpConverter>(context);
+}
+
std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
return std::make_unique<ShapeToShapeLowering>();
}
More information about the Mlir-commits mailing list