[Mlir-commits] [mlir] 72de758 - [mlir][SCF] Add bufferization hook for scf.foreach_thread and terminator.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jun 3 00:19:43 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-03T07:14:05Z
New Revision: 72de7588cc8bf54b48f66e649f621ec182435e1a

URL: https://github.com/llvm/llvm-project/commit/72de7588cc8bf54b48f66e649f621ec182435e1a
DIFF: https://github.com/llvm/llvm-project/commit/72de7588cc8bf54b48f66e649f621ec182435e1a.diff

LOG: [mlir][SCF] Add bufferization hook for scf.foreach_thread and terminator.

Each `scf.foreach_thread` result aliases the destination operand of the corresponding `scf.foreach_thread.parallel_insert_slice`;
in the absence of other conflicts, the two bufferize to equivalent buffers.
`scf.foreach_thread.parallel_insert_slice` conflict detection is similar to `tensor.insert_slice` conflict detection.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D126769

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 8c9a1e3ad1d8..3e20412b1736 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -448,6 +448,12 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 
+  // Default builders leave the region without a block; add our own builder.
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins)>,
+  ];
+
   // TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
   // appear inside perform_concurrently.
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index a160ba028928..ecb66faee199 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1138,10 +1138,11 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder,
   result.addOperands(numThreads);
 
   Region *bodyRegion = result.addRegion();
-  {
-    OpBuilder::InsertionGuard g(builder);
-    builder.createBlock(bodyRegion);
-  }
+  OpBuilder::InsertionGuard g(builder);
+  // createBlock sets the insertion point inside the new block. Generally we
+  // would guard against that, but the default ensureTerminator implementation
+  // expects it; `g` above restores the caller's insertion point on return.
+  builder.createBlock(bodyRegion);
   Block &bodyBlock = bodyRegion->front();
   bodyBlock.addArguments(
       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
@@ -1158,8 +1159,9 @@ void ForeachThreadOp::build(
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
   result.addOperands(numThreads);
 
+  OpBuilder::InsertionGuard g(builder);
   Region *bodyRegion = result.addRegion();
-  bodyRegion->push_back(new Block);
+  builder.createBlock(bodyRegion);
   Block &bodyBlock = bodyRegion->front();
   bodyBlock.addArguments(
       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
@@ -1167,9 +1169,11 @@ void ForeachThreadOp::build(
 
   OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPointToStart(&bodyBlock);
-  bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
+  bodyBuilder(builder, result.location, bodyBlock.getArguments());
   auto terminator =
-      llvm::cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
+      llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
+  assert(terminator &&
+         "expected bodyBuilder to create PerformConcurrentlyOp terminator");
   result.addTypes(terminator.yieldedTypes());
 }
 
@@ -1272,6 +1276,13 @@ void ParallelInsertSliceOp::getCanonicalizationPatterns(
 // PerformConcurrentlyOp
 //===----------------------------------------------------------------------===//
 
+// Build a PerformConcurrentlyOp whose region contains a single empty block.
+void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
+  OpBuilder::InsertionGuard g(b);
+  Region *bodyRegion = result.addRegion();
+  b.createBlock(bodyRegion);
+}
+
 LogicalResult PerformConcurrentlyOp::verify() {
   // TODO: PerformConcurrentlyOpInterface.
   for (const Operation &op : getRegion().front().getOperations())

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 2a6a95e5e0d1..79bef06dfc5f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -812,6 +813,289 @@ struct YieldOpInterface
   }
 };
 
+using tensor::ExtractSliceOp;
+
+/// Return the destination operands that a ForeachThreadOp inserts into: one
+/// per ParallelInsertSliceOp in its terminator, in walk order.
+static SmallVector<OpOperand *>
+getInsertionDest(ForeachThreadOp foreachThreadOp) {
+  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
+  SmallVector<OpOperand *> result;
+  terminator.walk([&](ParallelInsertSliceOp insertOp) {
+    result.push_back(&insertOp->getOpOperand(1) /*dest*/);
+  });
+  return result;
+}
+
+/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
+/// region. There are op interfaces for the terminators (PerformConcurrentlyOp
+/// and ParallelInsertSliceOp), but these are only used during analysis, not
+/// for bufferization.
+struct ForeachThreadOpInterface
+    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
+                                                    ForeachThreadOp> {
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       const AnalysisState &state) const {
+    // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
+    auto foreachThreadOp = cast<ForeachThreadOp>(op);
+    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
+  }
+
+  bool isMemoryWrite(Operation *op, OpResult opResult,
+                     const AnalysisState &state) const {
+    // This op is a memory write. Stop lookup here to avoid finding false
+    // conflicts involving this op and one of the ops in the region. This is
+    // similar to how scf.if ops are analyzed.
+    return true;
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &b,
+                          BufferizationState &state) const {
+    OpBuilder::InsertionGuard g(b);
+    auto foreachThreadOp = cast<ForeachThreadOp>(op);
+
+    // Gather new results of the ForeachThreadOp.
+    SmallVector<Value> newResults;
+    for (OpResult opResult : foreachThreadOp->getOpResults()) {
+      SmallVector<OpOperand *> insertDestOperands =
+          state.getAnalysisState().getAliasingOpOperand(opResult);
+      assert(insertDestOperands.size() == 1 &&
+             "expected exactly one aliasing OpOperand");
+      // Insert copies right before the ForeachThreadOp (`op`). They must not
+      // be inside the terminator, which would be the default insertion
+      // point.
+      Value buffer = *state.getBuffer(b, *insertDestOperands.front(),
+                                      /*forceInPlace=*/llvm::None,
+                                      /*customCopyInsertionPoint=*/op);
+      newResults.push_back(buffer);
+    }
+
+    // Create new ForeachThreadOp without any results and drop the automatically
+    // introduced terminator (via the rewriter, like the insert ops below).
+    TypeRange newResultTypes;
+    auto newForeachThreadOp =
+        b.create<ForeachThreadOp>(foreachThreadOp.getLoc(), newResultTypes,
+                                  foreachThreadOp.getNumThreads());
+    b.eraseOp(newForeachThreadOp.getBody()->getTerminator());
+
+    // Move over block contents of the old op.
+    b.mergeBlocks(foreachThreadOp.getBody(), newForeachThreadOp.getBody(),
+                  newForeachThreadOp.getBody()->getArguments());
+
+    // Bufferize terminator.
+    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(
+        newForeachThreadOp.getBody()->getTerminator());
+    b.setInsertionPoint(performConcurrentlyOp);
+    unsigned resultCounter = 0;
+    WalkResult walkResult =
+        performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
+          Location loc = insertOp.getLoc();
+          Type srcType = getMemRefType(
+              insertOp.getSource().getType().cast<RankedTensorType>(),
+              state.getOptions());
+          // ParallelInsertSliceOp bufferizes to a copy.
+          auto srcMemref = b.create<bufferization::ToMemrefOp>(
+              loc, srcType, insertOp.getSource());
+          Value destMemref = newResults[resultCounter++];
+          Value subview = b.create<memref::SubViewOp>(
+              loc, destMemref, insertOp.getMixedOffsets(),
+              insertOp.getMixedSizes(), insertOp.getMixedStrides());
+          // This memcpy will fold away if everything bufferizes in-place.
+          if (failed(state.getOptions().createMemCpy(b, loc, srcMemref,
+                                                     subview)))
+            return WalkResult::interrupt();
+          b.eraseOp(insertOp);
+          return WalkResult::advance();
+        });
+    if (walkResult.wasInterrupted())
+      return failure();
+
+    // Replace the op.
+    replaceOpWithBufferizedValues(b, op, newResults);
+
+    return success();
+  }
+};
+
+/// Nothing to do: PerformConcurrentlyOp is bufferized with its ForeachThreadOp.
+struct PerformConcurrentlyOpInterface
+    : public BufferizableOpInterface::ExternalModel<
+          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
+  LogicalResult bufferize(Operation *op, RewriterBase &b,
+                          BufferizationState &state) const {
+    assert(false && "op does not have any tensor OpOperands / OpResults");
+    return failure();
+  }
+};
+
+/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches:
+/// equivalent source / dest buffers and identical offsets/sizes/strides.
+static bool areEquivalentExtractSliceOps(const AnalysisState &state,
+                                         ExtractSliceOp st,
+                                         ParallelInsertSliceOp sti) {
+  if (!st || !sti)
+    return false;
+  // st and sti are distinct ops: require equivalent source/dest buffers.
+  if (!state.areEquivalentBufferizedValues(st.source(), sti.getDest()))
+    return false;
+  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+    return false;
+  return true;
+}
+
+/// Return true if `value` originates, on all reverse use-def paths, from an
+/// ExtractSliceOp that matches the given ParallelInsertSliceOp.
+static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
+                                      ParallelInsertSliceOp insertOp) {
+  auto condition = [&](Value val) {
+    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
+        return true;
+    return false;
+  };
+
+  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
+                      condition);
+}
+
+/// Analysis-only model of ParallelInsertSliceOp; see ForeachThreadOpInterface.
+struct ParallelInsertSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel<
+          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    if (&opOperand != &op->getOpOperand(1) /*dest*/)
+      return {};
+
+    // ParallelInsertSliceOp itself has no results. Tensors are returned via
+    // the parent op.
+    auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>();
+    assert(foreachThreadOp &&
+           "could not find valid owner of parallel_insert_slice");
+
+    // The i-th ParallelInsertSliceOp in the terminator is returned via the i-th
+    // OpResult of the parent ForeachThreadOp.
+    Block *block = op->getBlock();
+    unsigned int opIdx = 0;
+    for (ParallelInsertSliceOp insertOp :
+         block->getOps<ParallelInsertSliceOp>()) {
+      if (insertOp.getOperation() == op)
+        break;
+      ++opIdx;
+    }
+    assert(opIdx < foreachThreadOp->getNumResults() &&
+           "could not find op inside terminator op");
+
+    return {foreachThreadOp->getResult(opIdx)};
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/;
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &b,
+                          BufferizationState &state) const {
+    // Will be bufferized as part of ForeachThreadOp.
+    return failure();
+  }
+
+  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
+  // the code.
+  bool isNotConflicting(Operation *op, OpOperand *uRead,
+                        OpOperand *uConflictingWrite,
+                        const AnalysisState &state) const {
+    Operation *readingOp = uRead->getOwner();
+    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+    // Special rules for matching ExtractSliceOp/ParallelInsertSliceOp pairs.
+    // If uRead is a ParallelInsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
+      // As an example, consider the following IR (using tensor.insert_slice).
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+
+      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
+                                    insertSliceOp))
+        // Case 1: The main insight is that the insert op reads only part of
+        // the destination tensor. The overwritten area is not read. If
+        // uConflictingWrite writes into exactly the memory location that is
+        // being read by uRead, this is not a conflict.
+        //
+        // In the above example:
+        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+        //
+        // The read of %t does not conflict with the write of the FillOp
+        // (same aliases!) because the area that the FillOp operates on is
+        // exactly the one that is *not* read via %t.
+        return true;
+
+      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+        // Case 2: The read of the source tensor and the write to the dest
+        // tensor via the insert op is not a conflict if the read is
+        // reading exactly that part of an equivalent tensor that the
+        // insert op is writing.
+        //
+        // In the above example:
+        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+        return true;
+    }
+
+    // If uConflictingWrite is a ParallelInsertSliceOp...
+    if (auto insertSliceOp =
+            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
+      // As an example, consider the following IR (using tensor.insert_slice).
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+      // %3 = vector.transfer_read %1, %cst
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of vector.transfer_read
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      // lastWrite         = %1
+      //
+      // This is not a conflict because the insert op overwrites the
+      // memory segment of %1 with the exact same data. (Effectively, there
+      // is no memory write here.)
+      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          state.areEquivalentBufferizedValues(uRead->get(),
+                                              insertSliceOp.getSource()) &&
+          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
+                                    insertSliceOp))
+        return true;
+
+    return false;
+  }
+};
+
 } // namespace
 } // namespace scf
 } // namespace mlir
@@ -822,6 +1106,11 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels(
     ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
     ForOp::attachInterface<ForOpInterface>(*ctx);
     IfOp::attachInterface<IfOpInterface>(*ctx);
+    ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
+    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
+        *ctx);
+    PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
+        *ctx);
     WhileOp::attachInterface<WhileOpInterface>(*ctx);
     YieldOp::attachInterface<YieldOpInterface>(*ctx);
   });

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 888eea82bbf7..00c977d52b99 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -486,3 +486,127 @@ func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>,
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_insert_slice_no_conflict(
+//  CHECK-SAME:     %[[idx:.*]]: index, %[[idx2:.*]]: index,
+//  CHECK-SAME:     %[[arg1:.*]]: memref<?xf32, #{{.*}}>,
+//  CHECK-SAME:     %[[arg2:.*]]: memref<?xf32, #{{.*}}>
+func.func @parallel_insert_slice_no_conflict(
+    %idx: index,
+    %idx2: index,
+    %arg1: tensor<?xf32> {bufferization.writable = true},
+    %arg2: tensor<?xf32> {bufferization.writable = true}) -> (tensor<?xf32>, f32) {
+  %cst = arith.constant 4.200000e+01 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]])  -> ()
+  %2 = scf.foreach_thread (%arg3) in (%idx2)  -> (tensor<?xf32>) {
+      // CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1]
+      %6 = tensor.extract_slice %arg2[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
+      // CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview]] : memref<?xf32
+      %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
+      // The self-copy is expected to fold away later.
+      // CHECK: memref.copy %[[subview]], %[[subview]]
+
+      // Empty terminator is elided from pretty-printing.
+      // CHECK-NOT: scf.foreach_thread.perform_concurrently
+      // CHECK-NOT: parallel_insert_slice
+      scf.foreach_thread.perform_concurrently {
+        scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] : 
+          tensor<?xf32> into tensor<?xf32>
+      }
+  }
+
+  // CHECK: %[[load:.*]] = memref.load %[[arg2]]
+  %f = tensor.extract %2[%c0] : tensor<?xf32>
+
+  // CHECK: return %[[load]] : f32
+  return %2, %f : tensor<?xf32>, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_insert_slice_with_conflict(
+//  CHECK-SAME:     %[[idx:.*]]: index, %[[idx2:.*]]: index,
+//  CHECK-SAME:     %[[arg1:.*]]: memref<?xf32, #{{.*}}>,
+//  CHECK-SAME:     %[[arg2:.*]]: memref<?xf32, #{{.*}}>
+func.func @parallel_insert_slice_with_conflict(
+    %idx: index, 
+    %idx2: index, 
+    %arg1: tensor<?xf32> {bufferization.writable = true},
+    %arg2: tensor<?xf32> {bufferization.writable = true}) -> (f32, f32)
+{
+  %cst = arith.constant 4.200000e+01 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // The parallel_insert_slice op bufferizes out-of-place due to a RAW conflict
+  // on %arg2, so we need an allocation.
+  // CHECK: %[[alloc1:.*]] = memref.alloc
+  // CHECK: memref.copy %[[arg2]], %[[alloc1]]
+
+  // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]])  -> ()
+  %2 = scf.foreach_thread (%arg3) in (%idx2)  -> (tensor<?xf32>) {
+      // Another alloc for the extract_slice op.
+      // CHECK: %[[alloc2:.*]] = memref.alloc
+      %6 = tensor.extract_slice %arg2[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: linalg.fill ins(%{{.*}}) outs(%[[alloc2]] : memref<?xf32
+      %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
+
+      // Now the copy for the parallel_insert_slice itself.
+      // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
+      //
+      // CHECK: memref.copy %[[alloc2]], %[[subview1]]
+      // CHECK: memref.dealloc %[[alloc2]]
+
+      // Empty terminator is elided from pretty-printing.
+      // CHECK-NOT: scf.foreach_thread.perform_concurrently
+      // CHECK-NOT: parallel_insert_slice
+      scf.foreach_thread.perform_concurrently {
+        scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
+          tensor<?xf32> into tensor<?xf32>
+      }
+  }
+
+  // CHECK: %[[load:.*]] = memref.load %[[arg2]]
+  // CHECK: %[[load2:.*]] = memref.load %[[alloc1]]
+  // CHECK: memref.dealloc %[[alloc1]]
+  %f = tensor.extract %arg2[%c0] : tensor<?xf32>
+  %f2 = tensor.extract %2[%c0] : tensor<?xf32>
+
+  // CHECK: return %[[load2]], %[[load]] : f32, f32
+  return %f2, %f : f32, f32
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (d0 * 4)>
+#map1 = affine_map<(d0) -> (d0 * 2)>
+
+// CHECK: #[[$DYN_LAYOUT_MAP:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+// CHECK-LABEL: func.func @matmul
+func.func @matmul(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<8x8xf32> {bufferization.writable = true}) -> tensor<8x8xf32> {
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  // CHECK: scf.foreach_thread {{.*}}  -> ()
+  %0 = scf.foreach_thread (%arg3, %arg4) in (%c2, %c4) -> (tensor<8x8xf32>) {
+    %1 = affine.apply #map0(%arg3)
+    %3 = tensor.extract_slice %arg0[%1, 0] [4, 8] [1, 1] : tensor<8x8xf32> to tensor<4x8xf32>
+    %4 = affine.apply #map1(%arg4)
+    %6 = tensor.extract_slice %arg1[0, %4] [8, 2] [1, 1] : tensor<8x8xf32> to tensor<8x2xf32>
+    %7 = tensor.extract_slice %arg2[%1, %4] [4, 2] [1, 1] : tensor<8x8xf32> to tensor<4x2xf32>
+    // 2x4 threads; each one computes a disjoint 4x2 tile of the 8x8 result.
+    //      CHECK: linalg.matmul ins({{.*}}memref<4x8xf32, #[[$DYN_LAYOUT_MAP]]>, memref<8x2xf32, #[[$DYN_LAYOUT_MAP]]>) outs({{.*}} : memref<4x2xf32, #[[$DYN_LAYOUT_MAP]]>)
+    %8 = linalg.matmul ins(%3, %6 : tensor<4x8xf32>, tensor<8x2xf32>) outs(%7 : tensor<4x2xf32>) -> tensor<4x2xf32>
+    scf.foreach_thread.perform_concurrently {
+      scf.foreach_thread.parallel_insert_slice %8 into %arg2[%1, %4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<8x8xf32>
+    }
+  }
+  return %0 : tensor<8x8xf32>
+}


        


More information about the Mlir-commits mailing list