[Mlir-commits] [mlir] 72de758 - [mlir][SCF] Add bufferization hook for scf.foreach_thread and terminator.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jun 3 00:19:43 PDT 2022
Author: Nicolas Vasilache
Date: 2022-06-03T07:14:05Z
New Revision: 72de7588cc8bf54b48f66e649f621ec182435e1a
URL: https://github.com/llvm/llvm-project/commit/72de7588cc8bf54b48f66e649f621ec182435e1a
DIFF: https://github.com/llvm/llvm-project/commit/72de7588cc8bf54b48f66e649f621ec182435e1a.diff
LOG: [mlir][SCF] Add bufferization hook for scf.foreach_thread and terminator.
`scf.foreach_thread` results alias the destination operands of the underlying `scf.foreach_thread.parallel_insert_slice` ops,
and they bufferize to equivalent buffers in the absence of other conflicts.
`scf.foreach_thread.parallel_insert_slice` conflict detection is similar to `tensor.insert_slice` conflict detection.
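For context, a minimal sketch of the op pair in question, adapted from the tests added below (names and shapes are illustrative):

    %res = scf.foreach_thread (%tid) in (%num_threads) -> (tensor<?xf32>) {
      %slice = tensor.extract_slice %dest[5] [%sz] [1] : tensor<?xf32> to tensor<?xf32>
      %filled = linalg.fill ins(%cst : f32) outs(%slice : tensor<?xf32>) -> tensor<?xf32>
      scf.foreach_thread.perform_concurrently {
        // %res aliases %dest: absent other conflicts, the result bufferizes to
        // a buffer equivalent to that of the destination operand.
        scf.foreach_thread.parallel_insert_slice %filled into %dest[5] [%sz] [1] :
          tensor<?xf32> into tensor<?xf32>
      }
    }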
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D126769
Added:
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 8c9a1e3ad1d8..3e20412b1736 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -448,6 +448,12 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
+ // The default builder does not add a region with an empty body; add our own.
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins)>,
+ ];
+
// TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
// appear inside perform_concurrently.
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index a160ba028928..ecb66faee199 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1138,10 +1138,11 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder,
result.addOperands(numThreads);
Region *bodyRegion = result.addRegion();
- {
- OpBuilder::InsertionGuard g(builder);
- builder.createBlock(bodyRegion);
- }
+ OpBuilder::InsertionGuard g(builder);
+ // createBlock sets the insertion point inside the new block. Generally we
+ // would guard against that, but the default ensureTerminator implementation
+ // expects it.
+ builder.createBlock(bodyRegion);
Block &bodyBlock = bodyRegion->front();
bodyBlock.addArguments(
SmallVector<Type>(numThreads.size(), builder.getIndexType()),
@@ -1158,8 +1159,9 @@ void ForeachThreadOp::build(
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
result.addOperands(numThreads);
+ OpBuilder::InsertionGuard g(builder);
Region *bodyRegion = result.addRegion();
- bodyRegion->push_back(new Block);
+ builder.createBlock(bodyRegion);
Block &bodyBlock = bodyRegion->front();
bodyBlock.addArguments(
SmallVector<Type>(numThreads.size(), builder.getIndexType()),
@@ -1167,9 +1169,11 @@ void ForeachThreadOp::build(
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(&bodyBlock);
- bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
+ bodyBuilder(builder, result.location, bodyBlock.getArguments());
auto terminator =
- llvm::cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
+ llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
+ assert(terminator &&
+ "expected bodyBuilder to create PerformConcurrentlyOp terminator");
result.addTypes(terminator.yieldedTypes());
}
@@ -1272,6 +1276,13 @@ void ParallelInsertSliceOp::getCanonicalizationPatterns(
// PerformConcurrentlyOp
//===----------------------------------------------------------------------===//
+// Build a PerformConcurrentlyOp with an empty region and an empty body block.
+void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
+ OpBuilder::InsertionGuard g(b);
+ Region *bodyRegion = result.addRegion();
+ b.createBlock(bodyRegion);
+}
+
LogicalResult PerformConcurrentlyOp::verify() {
// TODO: PerformConcurrentlyOpInterface.
for (const Operation &op : getRegion().front().getOperations())
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 2a6a95e5e0d1..79bef06dfc5f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
@@ -812,6 +813,289 @@ struct YieldOpInterface
}
};
+using tensor::ExtractSliceOp;
+
+/// Return the destinations that a ForeachThreadOp inserts into, one per
+/// ParallelInsertSliceOp.
+static SmallVector<OpOperand *>
+getInsertionDest(ForeachThreadOp foreachThreadOp) {
+ PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
+ SmallVector<OpOperand *> result;
+ terminator.walk([&](ParallelInsertSliceOp insertOp) {
+ result.push_back(&insertOp->getOpOperand(1) /*dest*/);
+ });
+ return result;
+}
+
+/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of
+/// the region. There are op interfaces for the terminators
+/// (PerformConcurrentlyOp and ParallelInsertSliceOp), but these are used only
+/// during analysis, not during bufferization.
+struct ForeachThreadOpInterface
+ : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
+ ForeachThreadOp> {
+ SmallVector<OpOperand *>
+ getAliasingOpOperand(Operation *op, OpResult opResult,
+ const AnalysisState &state) const {
+ // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
+ auto foreachThreadOp = cast<ForeachThreadOp>(op);
+ return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
+ }
+
+ bool isMemoryWrite(Operation *op, OpResult opResult,
+ const AnalysisState &state) const {
+ // This op is a memory write. Stop lookup here to avoid finding false
+ // conflicts involving this op and one of the ops in the region. This is
+ // similar to how scf.if ops are analyzed.
+ return true;
+ }
+
+ BufferRelation bufferRelation(Operation *op, OpResult opResult,
+ const AnalysisState &state) const {
+ return BufferRelation::Equivalent;
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &b,
+ BufferizationState &state) const {
+ OpBuilder::InsertionGuard g(b);
+ auto foreachThreadOp = cast<ForeachThreadOp>(op);
+
+ // Gather new results of the ForeachThreadOp.
+ SmallVector<Value> newResults;
+ for (OpResult opResult : foreachThreadOp->getOpResults()) {
+ SmallVector<OpOperand *> insertDestOperands =
+ state.getAnalysisState().getAliasingOpOperand(opResult);
+ assert(insertDestOperands.size() == 1 &&
+ "expected exactly one aliasing OpOperand");
+ // Insert copies right before the PerformConcurrentlyOp terminator. They
+ // should not be inside the terminator (which would be the default
+ // insertion point).
+ Value buffer = *state.getBuffer(b, *insertDestOperands.front(),
+ /*forceInPlace=*/llvm::None,
+ /*customCopyInsertionPoint=*/op);
+ newResults.push_back(buffer);
+ }
+
+ // Create a new ForeachThreadOp without any results and drop the
+ // automatically introduced terminator.
+ TypeRange newResultTypes;
+ auto newForeachThreadOp =
+ b.create<ForeachThreadOp>(foreachThreadOp.getLoc(), newResultTypes,
+ foreachThreadOp.getNumThreads());
+ newForeachThreadOp.getBody()->getTerminator()->erase();
+
+ // Move over block contents of the old op.
+ b.mergeBlocks(foreachThreadOp.getBody(), newForeachThreadOp.getBody(),
+ {newForeachThreadOp.getBody()->getArguments()});
+
+ // Bufferize terminator.
+ auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(
+ newForeachThreadOp.getBody()->getTerminator());
+ b.setInsertionPoint(performConcurrentlyOp);
+ unsigned resultCounter = 0;
+ WalkResult walkResult =
+ performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
+ Location loc = insertOp.getLoc();
+ Type srcType = getMemRefType(
+ insertOp.getSource().getType().cast<RankedTensorType>(),
+ state.getOptions());
+ // ParallelInsertSliceOp bufferizes to a copy.
+ auto srcMemref = b.create<bufferization::ToMemrefOp>(
+ loc, srcType, insertOp.getSource());
+ Value destMemref = newResults[resultCounter++];
+ Value subview = b.create<memref::SubViewOp>(
+ loc, destMemref, insertOp.getMixedOffsets(),
+ insertOp.getMixedSizes(), insertOp.getMixedStrides());
+ // This memcpy will fold away if everything bufferizes in-place.
+ if (failed(state.getOptions().createMemCpy(b, insertOp.getLoc(),
+ srcMemref, subview)))
+ return WalkResult::interrupt();
+ b.eraseOp(insertOp);
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted())
+ return failure();
+
+ // Replace the op.
+ replaceOpWithBufferizedValues(b, op, newResults);
+
+ return success();
+ }
+};
+
+/// Nothing to do for PerformConcurrentlyOp.
+struct PerformConcurrentlyOpInterface
+ : public BufferizableOpInterface::ExternalModel<
+ PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
+ LogicalResult bufferize(Operation *op, RewriterBase &b,
+ BufferizationState &state) const {
+ assert(false && "op does not have any tensor OpOperands / OpResults");
+ return failure();
+ }
+};
+
+/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches,
+/// i.e., equivalent source/dest values and the same offsets/sizes/strides.
+static bool areEquivalentExtractSliceOps(const AnalysisState &state,
+ ExtractSliceOp st,
+ ParallelInsertSliceOp sti) {
+ if (!st || !sti)
+ return false;
+ if (st != sti &&
+ !state.areEquivalentBufferizedValues(st.source(), sti.getDest()))
+ return false;
+ if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+ return false;
+ return true;
+}
+
+/// Return true if `value` originates from an ExtractSliceOp that matches the
+/// given ParallelInsertSliceOp.
+static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
+ ParallelInsertSliceOp insertOp) {
+ auto condition = [&](Value val) {
+ if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+ if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
+ return true;
+ return false;
+ };
+
+ return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
+ condition);
+}
+
+/// Analysis of ParallelInsertSliceOp.
+struct ParallelInsertSliceOpInterface
+ : public BufferizableOpInterface::ExternalModel<
+ ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ if (&opOperand != &op->getOpOperand(1) /*dest*/)
+ return {};
+
+ // ParallelInsertSliceOp itself has no results. Tensors are returned via
+ // the parent op.
+ auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>();
+ assert(foreachThreadOp &&
+ "could not find valid owner of parallel_insert_slice");
+
+ // The destination of the i-th ParallelInsertSliceOp is returned via the
+ // i-th OpResult of the parent ForeachThreadOp.
+ Block *block = op->getBlock();
+ unsigned int opIdx = 0;
+ for (ParallelInsertSliceOp insertOp :
+ block->getOps<ParallelInsertSliceOp>()) {
+ if (insertOp.getOperation() == op)
+ break;
+ ++opIdx;
+ }
+ assert(opIdx < foreachThreadOp->getNumResults() &&
+ "could not find op inside terminator op");
+
+ return {foreachThreadOp->getResult(opIdx)};
+ }
+
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return &opOperand == &op->getOpOperand(1) /*dest*/;
+ }
+
+ BufferRelation bufferRelation(Operation *op, OpResult opResult,
+ const AnalysisState &state) const {
+ return BufferRelation::Equivalent;
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &b,
+ BufferizationState &state) const {
+ // Will be bufferized as part of ForeachThreadOp.
+ return failure();
+ }
+
+ // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
+ // the code.
+ bool isNotConflicting(Operation *op, OpOperand *uRead,
+ OpOperand *uConflictingWrite,
+ const AnalysisState &state) const {
+ Operation *readingOp = uRead->getOwner();
+ Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+ // Special rules for matching ExtractSliceOp/ParallelInsertSliceOp pairs.
+ // If uRead is a ParallelInsertSliceOp...
+ if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
+ // As an example, consider the following IR.
+ //
+ // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+ // %1 = linalg.fill %cst, %0 {inplace= [true] }
+ // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+ // {inplace= [true] }
+
+ // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+ if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+ hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
+ insertSliceOp))
+ // Case 1: The main insight is that InsertSliceOp reads only part of
+ // the destination tensor. The overwritten area is not read. If
+ // uConflictingWrite writes into exactly the memory location that is
+ // being read by uRead, this is not a conflict.
+ //
+ // In the above example:
+ // uRead = OpOperand 1 (%t) of tensor.insert_slice
+ // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+ //
+ // The read of %t does not conflict with the write of the FillOp
+ // (same aliases!) because the area that the FillOp operates on is
+ // exactly the one that is *not* read via %t.
+ return true;
+
+ if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+ uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+ hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+ // Case 2: The read of the source tensor and the write to the dest
+ // tensor via an InsertSliceOp is not a conflict if the read is
+ // reading exactly that part of an equivalent tensor that the
+ // InsertSliceOp is writing.
+ //
+ // In the above example:
+ // uRead = OpOperand 0 (%1) of tensor.insert_slice
+ // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+ return true;
+ }
+
+ // If uConflictingWrite is a ParallelInsertSliceOp...
+ if (auto insertSliceOp =
+ dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
+ // As an example, consider the following IR.
+ //
+ // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+ // %1 = linalg.fill %cst, %0 {inplace= [true] }
+ // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+ // {inplace= [true] }
+ // %3 = vector.transfer_read %1, %cst
+ //
+ // In the above example:
+ // uRead = OpOperand 0 (%1) of vector.transfer_read
+ // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+ // lastWrite = %1
+ //
+ // This is not a conflict because the InsertSliceOp overwrites the
+ // memory segment of %1 with the exact same data. (Effectively, there
+ // is no memory write here.)
+ if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+ state.areEquivalentBufferizedValues(uRead->get(),
+ insertSliceOp.getSource()) &&
+ hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
+ insertSliceOp))
+ return true;
+
+ return false;
+ }
+};
+
} // namespace
} // namespace scf
} // namespace mlir
@@ -822,6 +1106,11 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels(
ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
ForOp::attachInterface<ForOpInterface>(*ctx);
IfOp::attachInterface<IfOpInterface>(*ctx);
+ ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
+ ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
+ *ctx);
+ PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
+ *ctx);
WhileOp::attachInterface<WhileOpInterface>(*ctx);
YieldOp::attachInterface<YieldOpInterface>(*ctx);
});
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 888eea82bbf7..00c977d52b99 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -486,3 +486,127 @@ func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>,
}
return
}
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_insert_slice_no_conflict(
+// CHECK-SAME: %[[idx:.*]]: index, %[[idx2:.*]]: index,
+// CHECK-SAME: %[[arg1:.*]]: memref<?xf32, #{{.*}}>,
+// CHECK-SAME: %[[arg2:.*]]: memref<?xf32, #{{.*}}>
+func.func @parallel_insert_slice_no_conflict(
+ %idx: index,
+ %idx2: index,
+ %arg1: tensor<?xf32> {bufferization.writable = true},
+ %arg2: tensor<?xf32> {bufferization.writable = true}) -> (tensor<?xf32>, f32) {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+
+ // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]]) -> ()
+ %2 = scf.foreach_thread (%arg3) in (%idx2) -> (tensor<?xf32>) {
+ // CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1]
+ %6 = tensor.extract_slice %arg2[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
+ // CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview]] : memref<?xf32
+ %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
+ // Self-copy will DCE away later.
+ // CHECK: memref.copy %[[subview]], %[[subview]]
+
+ // Empty terminator is elided from pretty-printing.
+ // CHECK-NOT: scf.foreach_thread.perform_concurrently
+ // CHECK-NOT: parallel_insert_slice
+ scf.foreach_thread.perform_concurrently {
+ scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
+ tensor<?xf32> into tensor<?xf32>
+ }
+ }
+
+ // CHECK: %[[load:.*]] = memref.load %[[arg2]]
+ %f = tensor.extract %2[%c0] : tensor<?xf32>
+
+ // CHECK: return %[[load]] : f32
+ return %2, %f : tensor<?xf32>, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_insert_slice_with_conflict(
+// CHECK-SAME: %[[idx:.*]]: index, %[[idx2:.*]]: index,
+// CHECK-SAME: %[[arg1:.*]]: memref<?xf32, #{{.*}}>,
+// CHECK-SAME: %[[arg2:.*]]: memref<?xf32, #{{.*}}>
+func.func @parallel_insert_slice_with_conflict(
+ %idx: index,
+ %idx2: index,
+ %arg1: tensor<?xf32> {bufferization.writable = true},
+ %arg2: tensor<?xf32> {bufferization.writable = true}) -> (f32, f32)
+{
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+
+ // The parallel_insert_slice op bufferizes out-of-place due to a RAW
+ // conflict on %arg2, so an allocation is needed.
+ // CHECK: %[[alloc1:.*]] = memref.alloc
+ // CHECK: memref.copy %[[arg2]], %[[alloc1]]
+
+ // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]]) -> ()
+ %2 = scf.foreach_thread (%arg3) in (%idx2) -> (tensor<?xf32>) {
+ // Another alloc for the extract_slice op.
+ // CHECK: %[[alloc2:.*]] = memref.alloc
+ %6 = tensor.extract_slice %arg2[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: linalg.fill ins(%{{.*}}) outs(%[[alloc2]] : memref<?xf32
+ %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
+
+ // Now the copy of the actual insert_slice.
+ // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
+ //
+ // CHECK: memref.copy %[[alloc2]], %[[subview1]]
+ // CHECK: memref.dealloc %[[alloc2]]
+
+ // Empty terminator is elided from pretty-printing.
+ // CHECK-NOT: scf.foreach_thread.perform_concurrently
+ // CHECK-NOT: parallel_insert_slice
+ scf.foreach_thread.perform_concurrently {
+ scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
+ tensor<?xf32> into tensor<?xf32>
+ }
+ }
+
+ // CHECK: %[[load:.*]] = memref.load %[[arg2]]
+ // CHECK: %[[load2:.*]] = memref.load %[[alloc1]]
+ // CHECK: memref.dealloc %[[alloc1]]
+ %f = tensor.extract %arg2[%c0] : tensor<?xf32>
+ %f2 = tensor.extract %2[%c0] : tensor<?xf32>
+
+ // CHECK: return %[[load2]], %[[load]] : f32, f32
+ return %f2, %f : f32, f32
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (d0 * 4)>
+#map1 = affine_map<(d0) -> (d0 * 2)>
+
+// CHECK: #[[$DYN_LAYOUT_MAP:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+// CHECK-LABEL: func.func @matmul
+func.func @matmul(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<8x8xf32> {bufferization.writable = true}) -> tensor<8x8xf32> {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+
+ // CHECK: scf.foreach_thread {{.*}} -> ()
+ %0 = scf.foreach_thread (%arg3, %arg4) in (%c2, %c4) -> (tensor<8x8xf32>) {
+ %1 = affine.apply #map0(%arg3)
+ %3 = tensor.extract_slice %arg0[%1, 0] [4, 8] [1, 1] : tensor<8x8xf32> to tensor<4x8xf32>
+ %4 = affine.apply #map1(%arg4)
+ %6 = tensor.extract_slice %arg1[0, %4] [8, 4] [1, 1] : tensor<8x8xf32> to tensor<8x4xf32>
+ %7 = tensor.extract_slice %arg2[%1, %4] [4, 4] [1, 1] : tensor<8x8xf32> to tensor<4x4xf32>
+
+ // CHECK: linalg.matmul ins({{.*}}memref<4x8xf32, #[[$DYN_LAYOUT_MAP]]>, memref<8x4xf32, #[[$DYN_LAYOUT_MAP]]>) outs({{.*}} : memref<4x4xf32, #[[$DYN_LAYOUT_MAP]]>)
+ %8 = linalg.matmul ins(%3, %6 : tensor<4x8xf32>, tensor<8x4xf32>) outs(%7 : tensor<4x4xf32>) -> tensor<4x4xf32>
+ scf.foreach_thread.perform_concurrently {
+ scf.foreach_thread.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
+ }
+ }
+ return %0 : tensor<8x8xf32>
+}
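To make the CHECK lines easier to follow, here is roughly the bufferized form of the no-conflict test above, assembled from its CHECK directives (a sketch, not verbatim pass output; the dynamic layout map is abbreviated as #map):

    scf.foreach_thread (%tidx) in (%idx2) -> () {
      %subview = memref.subview %arg2[5] [%idx] [1]
          : memref<?xf32, #map> to memref<?xf32, #map>
      linalg.fill ins(%cst : f32) outs(%subview : memref<?xf32, #map>)
      // Everything bufferized in-place, so the insert became a self-copy that
      // later DCEs away.
      memref.copy %subview, %subview : memref<?xf32, #map> to memref<?xf32, #map>
    }
    %f = memref.load %arg2[%c0] : memref<?xf32, #map>
    return %f : f32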