[Mlir-commits] [mlir] c0db8d5 - [mlir] Expose a function to populate tensor constant bufferization patterns

Benjamin Kramer llvmlistbot at llvm.org
Wed Jun 9 04:52:06 PDT 2021


Author: Benjamin Kramer
Date: 2021-06-09T13:47:33+02:00
New Revision: c0db8d50ca3ceb1301b2ade2fb86c591a5b64e5c

URL: https://github.com/llvm/llvm-project/commit/c0db8d50ca3ceb1301b2ade2fb86c591a5b64e5c
DIFF: https://github.com/llvm/llvm-project/commit/c0db8d50ca3ceb1301b2ade2fb86c591a5b64e5c.diff

LOG: [mlir] Expose a function to populate tensor constant bufferization patterns

This makes it easier to reuse the tensor constant bufferization patterns from other bufferization passes.

Differential Revision: https://reviews.llvm.org/D103838

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
    mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index 2b7f3da150cdf..c7e331e856519 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -19,6 +19,7 @@
 
 namespace mlir {
 
+class GlobalCreator;
 class RewritePatternSet;
 using OwningRewritePatternList = RewritePatternSet;
 
@@ -31,6 +32,12 @@ std::unique_ptr<Pass> createStdBufferizePass();
 /// Creates an instance of func bufferization pass.
 std::unique_ptr<Pass> createFuncBufferizePass();
 
+/// Add patterns to bufferize tensor constants into global memrefs to the given
+/// pattern list.
+void populateTensorConstantBufferizePatterns(
+    GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter,
+    RewritePatternSet &patterns);
+
 /// Creates an instance of tensor constant bufferization pass.
 std::unique_ptr<Pass> createTensorConstantBufferizePass();
 

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
index b40e47c944141..518405aabb49f 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
@@ -81,6 +81,13 @@ class BufferizeTensorConstantOp : public OpConversionPattern<ConstantOp> {
 };
 } // namespace
 
+void mlir::populateTensorConstantBufferizePatterns(
+    GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  patterns.add<BufferizeTensorConstantOp>(globalCreator, typeConverter,
+                                          patterns.getContext());
+}
+
 namespace {
 struct TensorConstantBufferizePass
     : public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
@@ -94,7 +101,7 @@ struct TensorConstantBufferizePass
     ConversionTarget target(*context);
 
     target.addLegalDialect<memref::MemRefDialect>();
-    patterns.add<BufferizeTensorConstantOp>(globals, typeConverter, context);
+    populateTensorConstantBufferizePatterns(globals, typeConverter, patterns);
     target.addDynamicallyLegalOp<ConstantOp>(
         [&](ConstantOp op) { return typeConverter.isLegal(op.getType()); });
     if (failed(applyPartialConversion(module, target, std::move(patterns))))


        


More information about the Mlir-commits mailing list