[Mlir-commits] [mlir] 04dac2c - [mlir][SCF][bufferize][NFC] Implement resolveConflicts for ParallelInsertSliceOp

Matthias Springer llvmlistbot at llvm.org
Tue Jun 28 03:23:22 PDT 2022


Author: Matthias Springer
Date: 2022-06-28T12:18:22+02:00
New Revision: 04dac2ca7c06d0ce173e53527e3b90a07e3b325d

URL: https://github.com/llvm/llvm-project/commit/04dac2ca7c06d0ce173e53527e3b90a07e3b325d
DIFF: https://github.com/llvm/llvm-project/commit/04dac2ca7c06d0ce173e53527e3b90a07e3b325d.diff

LOG: [mlir][SCF][bufferize][NFC] Implement resolveConflicts for ParallelInsertSliceOp

This was previously implemented as part of the BufferizableOpInterface of ForeachThreadOp. This commit moves the implementation to ParallelInsertSliceOp to be consistent with the remaining ops and to have a nice example op that can serve as a blueprint for other ops.

Differential Revision: https://reviews.llvm.org/D128666

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index cde966592212d..5de21cc350ac1 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -578,6 +578,10 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
     /// Return the number of leading operands before `offsets`, `sizes` and
     /// `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+    /// Return the OpResult of the enclosing ForeachThreadOp that is
+    /// corresponding to this ParallelInsertSliceOp.
+    OpResult getTiedOpResult();
   }];
 
   let builders = [

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index bd0f16dbd0e07..557a9edc2f18e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1215,6 +1215,18 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
 // ParallelInsertSliceOp
 //===----------------------------------------------------------------------===//
 
+OpResult ParallelInsertSliceOp::getTiedOpResult() {
+  auto foreachThreadOp = getOperation()->getParentOfType<ForeachThreadOp>();
+  assert(foreachThreadOp && "unlinked ParallelInsertSliceOp");
+  PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator();
+  for (const auto &it : llvm::enumerate(performConcurrentlyOp.yieldingOps())) {
+    Operation &nextOp = it.value();
+    if (&nextOp == getOperation())
+      return foreachThreadOp->getResult(it.index()); // Matched by position.
+  }
+  llvm_unreachable("ParallelInsertSliceOp not found");
+}
+
 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                   Value source, Value dest,

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 0227816b77845..1f6359bfbb497 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -961,42 +961,6 @@ struct ForeachThreadOpInterface
     return BufferRelation::Equivalent;
   }
 
-  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &state) const {
-    auto bufferizableOp = cast<BufferizableOpInterface>(op);
-    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
-      return failure();
-
-    OpBuilder::InsertionGuard g(rewriter);
-    auto foreachThreadOp = cast<ForeachThreadOp>(op);
-    for (OpResult opResult : foreachThreadOp->getOpResults()) {
-      SmallVector<OpOperand *> destOperands =
-          state.getAliasingOpOperand(opResult);
-      assert(destOperands.size() == 1 &&
-             "expected exactly one aliasing OpOperand");
-      assert(isa<ParallelInsertSliceOp>(destOperands.front()->getOwner()) &&
-             "expected ParallelInsertSliceOp");
-
-      // Nothing to do if there is no conflict.
-      if (state.isInPlace(*destOperands.front()))
-        continue;
-
-      // Insert tensor allocation.
-      bool isYielded = state.isTensorYielded(opResult);
-      FailureOr<Value> alloc = allocateTensorForShapedValue(
-          rewriter, op->getLoc(), destOperands.front()->get(),
-          /*escape=*/isYielded, state.getOptions());
-      if (failed(alloc))
-        return failure();
-
-      // Update terminator operand.
-      rewriter.updateRootInPlace(destOperands.front()->getOwner(),
-                                 [&]() { destOperands.front()->set(*alloc); });
-    }
-
-    return success();
-  }
-
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto foreachThreadOp = cast<ForeachThreadOp>(op);
@@ -1118,7 +1082,55 @@ struct ParallelInsertSliceOpInterface
 
   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                  const AnalysisState &state) const {
-    // RaW conflicts are resolved as part of ForeachThreadOp.
+    // This interface method is overridden because we want to set a custom
+    // insertion point for tensor copies. They should be inserted right before
+    // the ForeachThreadOp. E.g.:
+    //
+    // %r0, %r1 = foreach_thread ... {
+    //   ...
+    //   perform_concurrently {
+    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
+    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
+    //   }
+    // }
+    //
+    // After TensorCopyInsertion:
+    //
+    // %copy = bufferization.alloc_tensor() copy(%d)
+    // %r0, %r1 = foreach_thread ... {
+    //   ...
+    //   perform_concurrently {
+    //     parallel_insert_slice %a into %b ...
+    //     parallel_insert_slice %c into %copy ...
+    //   }
+    // }
+
+    OpBuilder::InsertionGuard g(rewriter);
+    auto insertOp = cast<ParallelInsertSliceOp>(op);
+    auto foreachThreadOp = insertOp->getParentOfType<ForeachThreadOp>();
+
+    // Nothing to do if the destination tensor is in-place.
+    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
+           "source is always in-place");
+    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
+      return success();
+
+    // Find corresponding OpResult.
+    OpResult opResult = insertOp.getTiedOpResult();
+
+    // Insert tensor allocation right before the ForeachThreadOp.
+    rewriter.setInsertionPoint(foreachThreadOp);
+    bool isYielded = state.isTensorYielded(opResult);
+    FailureOr<Value> alloc =
+        allocateTensorForShapedValue(rewriter, op->getLoc(), insertOp.getDest(),
+                                     /*escape=*/isYielded, state.getOptions());
+    if (failed(alloc))
+      return failure();
+
+    // Update destination operand.
+    rewriter.updateRootInPlace(
+        insertOp, [&]() { insertOp.getDestMutable().assign(*alloc); });
+
     return success();
   }
 
@@ -1149,29 +1161,20 @@ struct ParallelInsertSliceOpInterface
     if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
                                     subview)))
       return failure();
-    rewriter.eraseOp(op);
 
     // Replace all uses of ForeachThreadOp (just the corresponding result).
     rewriter.setInsertionPointAfter(foreachThreadOp);
     Value toTensorOp =
         rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
-    // PerformConcurrentlyOp can have multiple ParallelInserSliceOps. Find the
-    // index of `op` within yielding ops.
-    unsigned resultNum = 0;
-    for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
-      if (&nextOp == op)
-        break;
-      resultNum++;
-    }
-    assert(resultNum < foreachThreadOp->getNumResults() &&
-           "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
-    SmallVector<OpOperand *> resultUses = llvm::to_vector(
-        llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
-                        [](OpOperand &use) { return &use; }));
+    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
+    SmallVector<OpOperand *> resultUses =
+        llvm::to_vector(llvm::map_range(insertOp.getTiedOpResult().getUses(),
+                                        [](OpOperand &use) { return &use; }));
     for (OpOperand *use : resultUses) {
       rewriter.updateRootInPlace(use->getOwner(),
                                  [&]() { use->set(toTensorOp); });
     }
+    rewriter.eraseOp(op);
     return success();
   }
 


        


More information about the Mlir-commits mailing list