[Mlir-commits] [mlir] a88732d - [mlir][bufferization][NFC] Extract block signature bufferization into separate function

Matthias Springer llvmlistbot at llvm.org
Thu Aug 17 02:22:22 PDT 2023


Author: Matthias Springer
Date: 2023-08-17T11:16:49+02:00
New Revision: a88732d98b0ccdb57c82635a3b97badd9755f99b

URL: https://github.com/llvm/llvm-project/commit/a88732d98b0ccdb57c82635a3b97badd9755f99b
DIFF: https://github.com/llvm/llvm-project/commit/a88732d98b0ccdb57c82635a3b97badd9755f99b.diff

LOG: [mlir][bufferization][NFC] Extract block signature bufferization into separate function

When bufferizing "func.func", the entry block signature is bufferized. (Only functions with a single block are supported at the moment.) This functionality is moved into a separate function, so that it can be used for bufferizing unstructured control flow in the future.

Differential Revision: https://reviews.llvm.org/D158154

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 49e74140626fb7..6b1994a5335f15 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -78,6 +78,15 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
                           const OpFilter *opFilter = nullptr,
                           BufferizationStatistics *statistics = nullptr);
 
+/// Bufferize the signature of `block`. Tensor block argument types are
+/// changed to memref types; all other argument types are kept.
+///
+/// It is expected that the parent op of this block implements the
+/// `BufferizableOpInterface`. The buffer types of tensor block arguments are
+/// computed with `BufferizableOpInterface::getBufferType`.
+LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
+                                      const BufferizationOptions &options);
+
 BufferizationOptions getPartialBufferizationOptions();
 
 } // namespace bufferization

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index da9b1d9868b571..67e749be6d7020 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -535,6 +535,56 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
   return success();
 }
 
+LogicalResult
+bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
+                                       const BufferizationOptions &options) {
+  OpBuilder::InsertionGuard g(rewriter); // Insertion point is moved below.
+  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
+  if (!bufferizableOp)
+    return failure(); // Parent op is expected to be bufferizable.
+
+  // Compute the new signature. Non-tensor bbArg types are kept unchanged.
+  SmallVector<Type> newTypes;
+  for (BlockArgument &bbArg : block->getArguments()) {
+    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
+    if (!tensorType) {
+      newTypes.push_back(bbArg.getType());
+      continue;
+    }
+
+    FailureOr<BaseMemRefType> memrefType =
+        bufferization::getBufferType(bbArg, options);
+    if (failed(memrefType))
+      return failure();
+    newTypes.push_back(*memrefType);
+  }
+
+  // Retype the bbArgs. Arguments whose type does not change are skipped.
+  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
+    if (bbArg.getType() == type)
+      continue;
+
+    // Snapshot the current uses before the to_tensor op adds a new one.
+    SmallVector<OpOperand *> bbArgUses;
+    for (OpOperand &use : bbArg.getUses())
+      bbArgUses.push_back(&use);
+
+    // Switch the bbArg to the buffer type computed above.
+    bbArg.setType(type);
+
+    // Existing users still expect a tensor; route them through to_tensor.
+    rewriter.setInsertionPointToStart(block);
+    if (!bbArgUses.empty()) { // Unused bbArgs need no to_tensor op.
+      Value toTensorOp =
+          rewriter.create<bufferization::ToTensorOp>(bbArg.getLoc(), bbArg);
+      for (OpOperand *use : bbArgUses)
+        use->set(toTensorOp);
+    }
+  }
+
+  return success();
+}
+
 BufferizationOptions bufferization::getPartialBufferizationOptions() {
   BufferizationOptions options;
   options.allowUnknownOps = true;

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 10c704fc64dd51..afac36fa9c6d71 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -374,37 +375,11 @@ struct FuncOpInterface
     assert(returnOp && "expected func with single return op");
     Location loc = returnOp.getLoc();
 
-    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
-    Block &frontBlock = funcOp.getBody().front();
-    for (BlockArgument &bbArg : frontBlock.getArguments()) {
-      auto tensorType = dyn_cast<TensorType>(bbArg.getType());
-      // Non-tensor types stay the same.
-      if (!tensorType)
-        continue;
-
-      // Collect all uses of the bbArg.
-      SmallVector<OpOperand *> bbArgUses;
-      for (OpOperand &use : bbArg.getUses())
-        bbArgUses.push_back(&use);
-
-      // Change the bbArg type to memref.
-      FailureOr<BaseMemRefType> memrefType =
-          bufferization::getBufferType(bbArg, options);
-      if (failed(memrefType))
+    // 1. Bufferize the signature of every block.
+    for (Block &block : funcOp.getBody())
+      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
+                                                        options)))
         return failure();
-      bbArg.setType(*memrefType);
-
-      // Replace all uses of the original tensor bbArg.
-      rewriter.setInsertionPointToStart(&frontBlock);
-      if (!bbArgUses.empty()) {
-        // Insert to_tensor because the remaining function body has not been
-        // bufferized yet.
-        Value toTensorOp =
-            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
-        for (OpOperand *use : bbArgUses)
-          use->set(toTensorOp);
-      }
-    }
 
     // 2. For each result, keep track of which inplace argument it reuses.
     SmallVector<Value> returnValues;


        


More information about the Mlir-commits mailing list