[Mlir-commits] [mlir] 7ebf70d - [mlir][SCF][bufferize][NFC] Bufferize parallel_insert_slice separately

Matthias Springer llvmlistbot at llvm.org
Mon Jun 27 04:16:11 PDT 2022


Author: Matthias Springer
Date: 2022-06-27T13:16:02+02:00
New Revision: 7ebf70d85d63acaaea56bdb1a13c5a2573868e1c

URL: https://github.com/llvm/llvm-project/commit/7ebf70d85d63acaaea56bdb1a13c5a2573868e1c
DIFF: https://github.com/llvm/llvm-project/commit/7ebf70d85d63acaaea56bdb1a13c5a2573868e1c.diff

LOG: [mlir][SCF][bufferize][NFC] Bufferize parallel_insert_slice separately

This allows for better type inference during bufferization and is in preparation for supporting memory spaces.

Differential Revision: https://reviews.llvm.org/D128580

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 785cd0d7806d..fb514a2f2b08 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -949,64 +949,32 @@ struct ForeachThreadOpInterface
     return success();
   }
 
-  LogicalResult bufferize(Operation *op, RewriterBase &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
-    OpBuilder::InsertionGuard g(b);
     auto foreachThreadOp = cast<ForeachThreadOp>(op);
 
-    // Gather new results of the ForeachThreadOp.
-    SmallVector<Value> newResults;
-    for (OpResult opResult : foreachThreadOp->getOpResults()) {
-      OpOperand *insertDest =
-          getInsertionDest(foreachThreadOp)[opResult.getResultNumber()];
-      // Insert copies right before the PerformConcurrentlyOp terminator. They
-      // should not be inside terminator (which would be the default insertion
-      // point).
-      Value buffer = getBuffer(b, insertDest->get(), options);
-      newResults.push_back(buffer);
-    }
+#ifndef NDEBUG
+    // ParallelInsertSliceOpInterface replaces all uses.
+    for (OpResult opResult : foreachThreadOp->getOpResults())
+      assert(opResult.getUses().empty() &&
+             "expected that all uses were already replaced");
+#endif // NDEBUG
 
     // Create new ForeachThreadOp without any results and drop the automatically
     // introduced terminator.
     TypeRange newResultTypes;
-    auto newForeachThreadOp =
-        b.create<ForeachThreadOp>(foreachThreadOp.getLoc(), newResultTypes,
-                                  foreachThreadOp.getNumThreads());
+    auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
+        foreachThreadOp.getLoc(), newResultTypes,
+        foreachThreadOp.getNumThreads());
     newForeachThreadOp.getBody()->getTerminator()->erase();
 
     // Move over block contents of the old op.
-    b.mergeBlocks(foreachThreadOp.getBody(), newForeachThreadOp.getBody(),
-                  {newForeachThreadOp.getBody()->getArguments()});
-
-    // Bufferize terminator.
-    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(
-        newForeachThreadOp.getBody()->getTerminator());
-    b.setInsertionPoint(performConcurrentlyOp);
-    unsigned resultCounter = 0;
-    WalkResult walkResult =
-        performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
-          Location loc = insertOp.getLoc();
-          Type srcType = getMemRefType(
-              insertOp.getSource().getType().cast<RankedTensorType>(), options);
-          // ParallelInsertSliceOp bufferizes to a copy.
-          auto srcMemref = b.create<bufferization::ToMemrefOp>(
-              loc, srcType, insertOp.getSource());
-          Value destMemref = newResults[resultCounter++];
-          Value subview = b.create<memref::SubViewOp>(
-              loc, destMemref, insertOp.getMixedOffsets(),
-              insertOp.getMixedSizes(), insertOp.getMixedStrides());
-          // This memcpy will fold away if everything bufferizes in-place.
-          if (failed(options.createMemCpy(b, insertOp.getLoc(), srcMemref,
-                                          subview)))
-            return WalkResult::interrupt();
-          b.eraseOp(insertOp);
-          return WalkResult::advance();
-        });
-    if (walkResult.wasInterrupted())
-      return failure();
+    rewriter.mergeBlocks(foreachThreadOp.getBody(),
+                         newForeachThreadOp.getBody(),
+                         {newForeachThreadOp.getBody()->getArguments()});
 
-    // Replace the op.
-    replaceOpWithBufferizedValues(b, op, newResults);
+    // Remove the old op.
+    rewriter.eraseOp(op);
 
     return success();
   }
@@ -1104,9 +1072,50 @@ struct ParallelInsertSliceOpInterface
     return success();
   }
 
-  LogicalResult bufferize(Operation *op, RewriterBase &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
-    // Will be bufferized as part of ForeachThreadOp.
+    OpBuilder::InsertionGuard g(rewriter);
+    auto insertOp = cast<ParallelInsertSliceOp>(op);
+    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
+    auto foreachThreadOp =
+        cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());
+
+    // If the op bufferizes out-of-place, allocate the copy before the
+    // ForeachThreadOp.
+    rewriter.setInsertionPoint(foreachThreadOp);
+    Value destBuffer = getBuffer(rewriter, insertOp.getDest(), options);
+
+    // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
+    rewriter.setInsertionPoint(performConcurrentlyOp);
+    Value srcBuffer = getBuffer(rewriter, insertOp.getSource(), options);
+    Value subview = rewriter.create<memref::SubViewOp>(
+        insertOp.getLoc(), destBuffer, insertOp.getMixedOffsets(),
+        insertOp.getMixedSizes(), insertOp.getMixedStrides());
+    // This memcpy will fold away if everything bufferizes in-place.
+    if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), srcBuffer,
+                                    subview)))
+      return failure();
+    rewriter.eraseOp(op);
+
+    // Replace all uses of ForeachThreadOp (just the corresponding result).
+    rewriter.setInsertionPointAfter(foreachThreadOp);
+    Value toTensorOp =
+        rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), destBuffer);
+    unsigned resultNum = 0;
+    for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
+      if (&nextOp == op)
+        break;
+      resultNum++;
+    }
+    assert(resultNum < foreachThreadOp->getNumResults() &&
+           "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
+    SmallVector<OpOperand *> resultUses = llvm::to_vector(
+        llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
+                        [](OpOperand &use) { return &use; }));
+    for (OpOperand *use : resultUses) {
+      rewriter.updateRootInPlace(use->getOwner(),
+                                 [&]() { use->set(toTensorOp); });
+    }
     return success();
   }
 


        


More information about the Mlir-commits mailing list