[Mlir-commits] [mlir] 3ff93f8 - [mlir][SCF][bufferize][NFC] Bufferize scf.for terminator separately

Matthias Springer llvmlistbot at llvm.org
Mon Jun 27 04:38:36 PDT 2022


Author: Matthias Springer
Date: 2022-06-27T13:26:32+02:00
New Revision: 3ff93f838ebd04a910d76b7f3242140d60a2e56d

URL: https://github.com/llvm/llvm-project/commit/3ff93f838ebd04a910d76b7f3242140d60a2e56d
DIFF: https://github.com/llvm/llvm-project/commit/3ff93f838ebd04a910d76b7f3242140d60a2e56d.diff

LOG: [mlir][SCF][bufferize][NFC] Bufferize scf.for terminator separately

This allows for better type inference during bufferization and is in preparation for supporting memory spaces.

Differential Revision: https://reviews.llvm.org/D128422

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index b9236b3573dd..d55e2784c199 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -445,6 +445,13 @@ struct ForOpInterface
     return success();
   }
 
+  BaseMemRefType getBufferType(Operation *op, BlockArgument bbArg,
+                               const BufferizationOptions &options) const {
+    auto forOp = cast<scf::ForOp>(op);
+    return bufferization::getBufferType(
+        forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto forOp = cast<scf::ForOp>(op);
@@ -474,20 +481,9 @@ struct ForOpInterface
         getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
     iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
 
-    // Erase terminator if present.
-    if (iterArgs.size() == 1)
-      rewriter.eraseOp(loopBody->getTerminator());
-
     // Move loop body to new loop.
     rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
 
-    // Update scf.yield of new loop.
-    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
-    rewriter.setInsertionPoint(yieldOp);
-    SmallVector<Value> yieldValues = getYieldedValues(
-        rewriter, yieldOp.getResults(), initArgsTypes, indices, options);
-    yieldOp.getResultsMutable().assign(yieldValues);
-
     // Replace loop results.
     replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
 
@@ -844,9 +840,9 @@ struct YieldOpInterface
             yieldOp->getParentOp()))
       return yieldOp->emitError("unsupported scf::YieldOp parent");
 
-    // TODO: Bufferize scf.yield inside scf.while/scf.for here.
-    // (Currently bufferized together with scf.while/scf.for.)
-    if (isa<scf::ForOp, scf::WhileOp>(yieldOp->getParentOp()))
+    // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
+    // together with scf.while.)
+    if (isa<scf::WhileOp>(yieldOp->getParentOp()))
       return success();
 
     SmallVector<Value> newResults;
@@ -854,6 +850,13 @@ struct YieldOpInterface
       Value value = it.value();
       if (value.getType().isa<TensorType>()) {
         Value buffer = getBuffer(rewriter, value, options);
+        if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
+          BaseMemRefType resultType =
+              cast<BufferizableOpInterface>(forOp.getOperation())
+                  .getBufferType(forOp.getRegionIterArgs()[it.index()],
+                                 options);
+          buffer = castBuffer(rewriter, buffer, resultType);
+        }
         newResults.push_back(buffer);
       } else {
         newResults.push_back(value);


        


More information about the Mlir-commits mailing list