[Mlir-commits] [mlir] 310deca - [mlir] Add loop bounds to scf.foreach_thread.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Feb 16 23:58:03 PST 2023
Author: Alexander Belyaev
Date: 2023-02-17T08:57:52+01:00
New Revision: 310deca248c84a9d8e529654437797327391fdc1
URL: https://github.com/llvm/llvm-project/commit/310deca248c84a9d8e529654437797327391fdc1
DIFF: https://github.com/llvm/llvm-project/commit/310deca248c84a9d8e529654437797327391fdc1.diff
LOG: [mlir] Add loop bounds to scf.foreach_thread.
https://discourse.llvm.org/t/rfc-parallel-loops-on-tensors-in-mlir/68332
Differential Revision: https://reviews.llvm.org/D144072
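For a quick feel of the two assembly forms this patch introduces, here is a sketch adapted from the test added in mlir/test/Dialect/SCF/ops.mlir (the function name is illustrative):

```mlir
func.func @copy_with_explicit_bounds(%in: tensor<100xf32>, %out: tensor<100xf32>) -> tensor<100xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c100 = arith.constant 100 : index
  // Explicit-bounds form: lower bounds, upper bounds, and steps are printed.
  %result = scf.foreach_thread (%i) = (%c0) to (%c100) step (%c1)
      shared_outs(%o = %out) -> tensor<100xf32> {
    %slice = tensor.extract_slice %in[%i] [1] [1] : tensor<100xf32> to tensor<1xf32>
    scf.foreach_thread.perform_concurrently {
      tensor.parallel_insert_slice %slice into %o[%i] [1] [1]
        : tensor<1xf32> into tensor<100xf32>
    }
  }
  // A loop with all-zero lower bounds and unit steps keeps the short,
  // normalized form: scf.foreach_thread (%i) in (%ub) ...
  return %result : tensor<100xf32>
}
```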
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
mlir/test/Dialect/SCF/ops.mlir
mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 24c37272d641e..6d627635bfe8e 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -353,16 +353,16 @@ def ForOp : SCF_Op<"for",
def ForeachThreadOp : SCF_Op<"foreach_thread", [
AttrSizedOperandSegments,
- SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
- RecursiveMemoryEffects,
AutomaticAllocationScope,
- ]> {
+ RecursiveMemoryEffects,
+ SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
+ ]> {
let summary = "evaluate a block multiple times in parallel";
let description = [{
`scf.foreach_thread` is a target-independent multi-dimensional parallel
region application operation. It has exactly one block that represents the
- parallel body and it takes index operands that indicate how many parallel
- instances of that function are created.
+ parallel body and it takes index operands that specify lower bounds, upper
+ bounds and steps.
The op also takes a variadic number of tensor operands (`shared_outs`).
The future buffers corresponding to these tensors are shared among all
@@ -404,7 +404,12 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
When the parallel function body has side effects, their order is unspecified
across threads.
- Example:
+ `scf.foreach_thread` can be printed in two different ways depending on
+ whether the loop is normalized or not. The loop is 'normalized' when all
+ lower bounds are equal to zero and steps are equal to one. In that case,
+ `lowerBound` and `step` operands will be omitted during printing.
+
+ Normalized loop example:
```mlir
//
@@ -442,6 +447,38 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
//
```
+ Example of a loop with explicit loop bounds:
+
+ ```mlir
+ //
+ // Sequential context.
+ //
+ %pointwise = scf.foreach_thread (%i, %j) = (0, 0) to (%dim1, %dim2)
+ step (%tileSize1, %tileSize2) shared_outs(%o = %out)
+ -> (tensor<?x?xT>) {
+ //
+ // Parallel context.
+ //
+ %sA = tensor.extract_slice %A[%i, %j][%tileSize1, %tileSize2][1, 1]
+ : tensor<?x?xT> to tensor<?x?xT>
+ %sB = tensor.extract_slice %B[%i, %j][%tileSize1, %tileSize2][1, 1]
+ : tensor<?x?xT> to tensor<?x?xT>
+ %sC = tensor.extract_slice %o[%i, %j][%tileSize1, %tileSize2][1, 1]
+ : tensor<?x?xT> to tensor<?x?xT>
+
+ %add = map {"arith.addf"} ins(%sA, %sB) outs(%sC)
+
+ scf.foreach_thread.perform_concurrently {
+ scf.foreach_thread.parallel_insert_slice %add into
+ %o[%i, %j][%tileSize1, %tileSize2][1, 1]
+ : tensor<?x?xT> into tensor<?x?xT>
+ }
+ }
+ // Implicit synchronization point.
+ // Sequential context.
+ //
+ ```
+
Example with mapping attribute:
```mlir
@@ -481,9 +518,15 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
}
```
}];
- let arguments = (ins Variadic<Index>:$num_threads,
- Variadic<AnyRankedTensor>:$outputs,
- OptionalAttr<DeviceMappingArrayAttr>:$mapping);
+ let arguments = (ins
+ Variadic<Index>:$dynamicLowerBound,
+ Variadic<Index>:$dynamicUpperBound,
+ Variadic<Index>:$dynamicStep,
+ DenseI64ArrayAttr:$staticLowerBound,
+ DenseI64ArrayAttr:$staticUpperBound,
+ DenseI64ArrayAttr:$staticStep,
+ Variadic<AnyRankedTensor>:$outputs,
+ OptionalAttr<DeviceMappingArrayAttr>:$mapping);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@@ -495,58 +538,114 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
// The default builder does not add the proper body BBargs, roll our own.
let skipDefaultBuilders = 1;
let builders = [
- // Bodyless builder, outputs must be specified.
- OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
- "std::optional<ArrayAttr>":$mapping)>,
- // Builder that takes a bodyBuilder lambda.
- OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
- "ArrayRef<Attribute>":$mapping,
- "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
+ // Builder that takes loop bounds.
+ OpBuilder<(ins "ArrayRef<OpFoldResult>":$lbs,
+ "ArrayRef<OpFoldResult>":$ubs, "ArrayRef<OpFoldResult>":$steps,
+ "ValueRange":$outputs, "std::optional<ArrayAttr>":$mapping,
+ CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>",
+ "nullptr"> :$bodyBuilderFn)>,
+
+ // Builder for normalized loop that takes only upper bounds.
+ OpBuilder<(ins "ArrayRef<OpFoldResult>":$ubs,
+ "ValueRange":$outputs, "std::optional<ArrayAttr>":$mapping,
+ CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>",
+ "nullptr"> :$bodyBuilderFn)>,
];
+
let extraClassDeclaration = [{
- int64_t getRank() { return getNumThreads().size(); }
+ /// Get lower bounds as OpFoldResult.
+ SmallVector<OpFoldResult> getMixedLowerBound() {
+ Builder b(getOperation()->getContext());
+ return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
+ }
+
+ /// Get upper bounds as OpFoldResult.
+ SmallVector<OpFoldResult> getMixedUpperBound() {
+ Builder b(getOperation()->getContext());
+ return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
+ }
+
+ /// Get steps as OpFoldResult.
+ SmallVector<OpFoldResult> getMixedStep() {
+ Builder b(getOperation()->getContext());
+ return getMixedValues(getStaticStep(), getDynamicStep(), b);
+ }
+
+ /// Get lower bounds as values.
+ SmallVector<Value> getLowerBound(OpBuilder &b) {
+ return getAsValues(b, getLoc(), getMixedLowerBound());
+ }
+
+ /// Get upper bounds as values.
+ SmallVector<Value> getUpperBound(OpBuilder &b) {
+ return getAsValues(b, getLoc(), getMixedUpperBound());
+ }
+
+ /// Get steps as values.
+ SmallVector<Value> getStep(OpBuilder &b) {
+ return getAsValues(b, getLoc(), getMixedStep());
+ }
+
+ int64_t getRank() { return getStaticLowerBound().size(); }
+
+ /// Number of operands controlling the loop: lbs, ubs, steps
+ unsigned getNumControlOperands() { return 3 * getRank(); }
+
+ /// Number of dynamic operands controlling the loop: lbs, ubs, steps
+ unsigned getNumDynamicControlOperands() {
+ return getODSOperandIndexAndLength(3).first;
+ }
OpResult getTiedOpResult(OpOperand *opOperand) {
- assert(opOperand->getOperandNumber() >= getRank() && "invalid operand");
+ assert(opOperand->getOperandNumber() >= getNumDynamicControlOperands() &&
+ "invalid operand");
return getOperation()->getOpResult(
- opOperand->getOperandNumber() - getRank());
+ opOperand->getOperandNumber() - getNumDynamicControlOperands());
}
/// Return the shared_outs operand that is tied to the given output
/// block argument.
OpOperand *getTiedOpOperand(BlockArgument bbArg) {
assert(bbArg.getArgNumber() >= getRank() && "invalid bbArg");
- return &getOperation()->getOpOperand(bbArg.getArgNumber());
+
+ return &getOperation()->getOpOperand(getNumDynamicControlOperands() +
+ bbArg.getArgNumber() - getRank());
}
/// Return the shared_outs operand that is tied to the given OpResult.
OpOperand *getTiedOpOperand(OpResult opResult) {
assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult");
- return &getOperation()->getOpOperand(
- opResult.getResultNumber() + getRank());
+ return &getOperation()->getOpOperand(getNumDynamicControlOperands() +
+ opResult.getResultNumber());
}
BlockArgument getTiedBlockArgument(OpOperand *opOperand) {
- assert(opOperand->getOperandNumber() >= getRank() && "invalid operand");
- return getBody()->getArgument(opOperand->getOperandNumber());
+ assert(opOperand->getOperandNumber() >= getNumDynamicControlOperands() &&
+ "invalid operand");
+
+ return getBody()->getArgument(opOperand->getOperandNumber() -
+ getNumDynamicControlOperands() + getRank());
}
ArrayRef<BlockArgument> getOutputBlockArguments() {
return getBody()->getArguments().drop_front(getRank());
}
- ::mlir::ValueRange getThreadIndices() {
+ ::mlir::ValueRange getInductionVars() {
return getBody()->getArguments().take_front(getRank());
}
- ::mlir::Value getThreadIndex(int64_t idx) {
- return getThreadIndices()[idx];
+ ::mlir::Value getInductionVar(int64_t idx) {
+ return getInductionVars()[idx];
}
::mlir::Block::BlockArgListType getRegionOutArgs() {
return getBody()->getArguments().drop_front(getRank());
}
+ /// Returns true if all lower bounds are zero and all steps are one.
+ bool isNormalized();
+
/// Helper to sort `values` according to matching `keys`.
/// Take a custom `compare` binary comparator which returns true if the first
/// element is smaller than the second (i.e. compatible with std::sort).
@@ -559,7 +658,8 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
// The ensureTerminator method generated by SingleBlockImplicitTerminator is
// unaware of the fact that our terminator also needs a region to be
// well-formed. We override it here to ensure that we do the right thing.
- static void ensureTerminator(Region &region, OpBuilder &builder, Location loc);
+ static void ensureTerminator(Region & region, OpBuilder & builder,
+ Location loc);
PerformConcurrentlyOp getTerminator();
}];
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index f950933b23c7a..5843ecd061df6 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -48,8 +48,10 @@ namespace mlir {
/// in `integers` is `dynVal` or (2) the next value otherwise. This allows
/// idiomatic printing of mixed value and integer attributes in a list. E.g.
/// `[%arg0, 7, 42, %arg42]`.
-void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
- OperandRange values, ArrayRef<int64_t> integers);
+void printDynamicIndexList(
+ OpAsmPrinter &printer, Operation *op, OperandRange values,
+ ArrayRef<int64_t> integers,
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
/// Parser hook for custom directive in assemblyFormat.
///
@@ -64,10 +66,11 @@ void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
/// 2. `ssa` is filled with "[%arg0, %arg1]".
-ParseResult
-parseDynamicIndexList(OpAsmParser &parser,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- DenseI64ArrayAttr &integers);
+ParseResult parseDynamicIndexList(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers,
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
/// Verify that `values` has as many elements as the number of entries in
/// `attr` for which `isDynamic` evaluates to true.
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 90f04f3713cf3..e971b764c6f05 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -180,16 +180,19 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
// transform op does not apply to individual ForeachThreadOp.
Location loc = foreachThreadOp->getLoc();
+ if (!foreachThreadOp.isNormalized())
+ return transformOp.emitSilenceableError()
+ << "unsupported non-normalized loops";
if (foreachThreadOp.getNumResults() > 0)
return transformOp.emitSilenceableError()
<< "only bufferized scf.foreach_thread lowers to "
"gpu.block_id";
- if (foreachThreadOp.getNumThreads().size() > 3)
+ if (foreachThreadOp.getRank() > 3)
return transformOp.emitSilenceableError()
<< "scf.foreach_thread with rank > 3 does not lower to "
"gpu.block_id";
- if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
- return !v.getDefiningOp<arith::ConstantIndexOp>();
+ if (llvm::any_of(foreachThreadOp.getMixedUpperBound(), [](OpFoldResult ofr) {
+ return !getConstantIntValue(ofr).has_value();
})) {
return transformOp.emitSilenceableError()
<< "unsupported dynamic griddim size";
@@ -198,8 +201,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
llvm::to_vector(foreachThreadOp.getMapping()->getValue());
// Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
- SmallVector<Value> numBlocks =
- llvm::to_vector(foreachThreadOp.getNumThreads());
+ SmallVector<Value> numBlocks = foreachThreadOp.getUpperBound(rewriter);
// Ensure we have 3 block sizes, one for each id.
Value one;
for (auto attr : mappingAttributes) {
@@ -227,7 +229,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
blockIdGenerator(rewriter, foreachThreadOp, blockOps);
IRMapping bvm;
for (auto [blockIdx, blockDim] :
- llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
+ llvm::zip(foreachThreadOp.getInductionVars(), blockMapping)) {
bvm.map(blockIdx,
blockOps[static_cast<int64_t>(
blockDim.cast<DeviceMappingAttrInterface>().getMappingId())]);
@@ -243,7 +245,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
sourceBlock.getOperations());
// Step 5. RAUW thread indices to thread ops.
- for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
+ for (Value loopIndex : foreachThreadOp.getInductionVars()) {
Value blockIdx = bvm.lookup(loopIndex);
rewriter.replaceAllUsesWith(loopIndex, blockIdx);
}
@@ -381,14 +383,16 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
return emitDefiniteFailure(foreachThreadOp, message);
};
Location loc = foreachThreadOp->getLoc();
+ if (!foreachThreadOp.isNormalized())
+ return failureHelper("unsupported non-normalized loops");
if (foreachThreadOp.getNumResults() > 0)
return failureHelper(
"only bufferized scf.foreach_thread lowers to gpu.thread_id");
- if (foreachThreadOp.getNumThreads().size() > 3)
+ if (foreachThreadOp.getRank() > 3)
return failureHelper(
"scf.foreach_thread with rank > 3 does not lower to gpu.thread_id");
- if (llvm::any_of(foreachThreadOp.getNumThreads(), [](Value v) {
- return !v.getDefiningOp<arith::ConstantIndexOp>();
+ if (llvm::any_of(foreachThreadOp.getMixedUpperBound(), [](OpFoldResult ofr) {
+ return !getConstantIntValue(ofr).has_value();
})) {
return failureHelper("unsupported dynamic blockdim size");
}
@@ -399,8 +403,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
// Step 1. Complete the threadMapping to a full mapping (with 1s) if
// necessary.
- SmallVector<Value> numThreads =
- llvm::to_vector(foreachThreadOp.getNumThreads());
+ SmallVector<Value> numThreads = foreachThreadOp.getUpperBound(rewriter);
// Ensure we have 3 block sizes, one for each id.
Value one;
for (auto attr : threadMappingAttributes) {
@@ -437,7 +440,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
}
IRMapping bvm;
for (auto [blockIdx, blockDim] :
- llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
+ llvm::zip(foreachThreadOp.getInductionVars(), threadMapping)) {
bvm.map(blockIdx,
threadOpsUpdated[blockDim.cast<DeviceMappingAttrInterface>()
.getMappingId()]);
@@ -484,7 +487,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
sourceBlock.getOperations());
// Step 6. RAUW thread indices to thread ops.
- for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
+ for (Value loopIndex : foreachThreadOp.getInductionVars()) {
Value threadIdx = bvm.lookup(loopIndex);
rewriter.replaceAllUsesWith(loopIndex, threadIdx);
}
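Both GPU mapping paths now bail out on non-normalized loops before doing anything else. As an illustrative (hypothetical) example of IR that would hit the new check:

```mlir
func.func @strided_foreach_thread() {
  // step != 1 means isNormalized() is false, so mapping this op to blocks
  // now fails with "unsupported non-normalized loops".
  scf.foreach_thread (%i) = (0) to (128) step (2) {
    scf.foreach_thread.perform_concurrently {
    }
  } {mapping = [#gpu.block<x>]}
  return
}
```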
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7303bd738906c..10a1451fb7cc3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -253,7 +253,7 @@ static void calculateTileOffsetsAndSizes(
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(foreachThreadOp.getBody(0));
- ValueRange threadIds = foreachThreadOp.getThreadIndices();
+ ValueRange threadIds = foreachThreadOp.getInductionVars();
SmallVector<OpFoldResult> nonZeroNumThreads =
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
return !isConstantIntValue(ofr, 0);
@@ -360,7 +360,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
// version because we require the use of RewriterBase in the body, so we
// manually move the insertion point to the body below.
scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
- loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
+ loc, getAsOpFoldResult((materializedNonZeroNumThreads)), dest, mapping);
// 2. Fill out the ForeachThreadOp body.
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
@@ -681,8 +681,8 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
// 2. Create the ForeachThreadOp with an empty region.
scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
- loc, (*identityTensor)->getResults(),
- ValueRange(materializedNonZeroNumThreads), mapping);
+ loc, getAsOpFoldResult(materializedNonZeroNumThreads),
+ (*identityTensor)->getResults(), mapping);
// 3. Calculate the tile offsets and sizes for the subsequent loop that will
// be nested under `foreachThreadOp`.
@@ -712,7 +712,7 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
b.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = tiledSizes;
sizes[reductionDim] = b.getIndexAttr(1);
- outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
+ outOffsets[reductionDim] = foreachThreadOp.getInductionVars().front();
// TODO: use SubsetExtractOpInterface once it is available.
tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
loc, initOperand->get().getType().cast<RankedTensorType>(),
@@ -746,7 +746,7 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
if (failed(maybeTiled))
return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
- SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
+ SmallVector<Value> ids = foreachThreadOp.getInductionVars();
mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
materializedNonZeroNumThreads);
assert(maybeTiled->loops.size() == 1 &&
@@ -774,7 +774,7 @@ linalg::tileReductionUsingForeachThread(RewriterBase &b,
int64_t sizeIdx = 0;
for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
if (i == reductionDim) {
- resultOffsetsRank.push_back(foreachThreadOp.getThreadIndices().front());
+ resultOffsetsRank.push_back(foreachThreadOp.getInductionVars().front());
resultSizesRank.push_back(b.getIndexAttr(1));
continue;
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 6a4da00bbad36..9032f533e5fbe 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1110,6 +1110,7 @@ Speculation::Speculatability ForOp::getSpeculatability() {
//===----------------------------------------------------------------------===//
LogicalResult ForeachThreadOp::verify() {
+ unsigned numLoops = getRank();
// Check number of outputs.
if (getNumResults() != getOutputs().size())
return emitOpError("produces ")
@@ -1118,18 +1119,18 @@ LogicalResult ForeachThreadOp::verify() {
// Check that the body defines block arguments for thread indices and outputs.
auto *body = getBody();
- if (body->getNumArguments() != getRank() + getOutputs().size())
- return emitOpError("region expects ") << getRank() << " arguments";
- for (int64_t i = 0; i < getRank(); ++i)
+ if (body->getNumArguments() != numLoops + getOutputs().size())
+ return emitOpError("region expects ") << numLoops << " arguments";
+ for (int64_t i = 0; i < numLoops; ++i)
if (!body->getArgument(i).getType().isIndex())
return emitOpError("expects ")
<< i << "-th block argument to be an index";
for (unsigned i = 0; i < getOutputs().size(); ++i)
- if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType())
+ if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
return emitOpError("type mismatch between ")
<< i << "-th output and corresponding block argument";
if (getMapping().has_value() && !getMapping()->empty()) {
- if (static_cast<int64_t>(getMapping()->size()) != getRank())
+ if (static_cast<int64_t>(getMapping()->size()) != numLoops)
return emitOpError() << "mapping attribute size must match op rank";
for (auto map : getMapping()->getValue()) {
if (!isa<DeviceMappingAttrInterface>(map))
@@ -1138,15 +1139,41 @@ LogicalResult ForeachThreadOp::verify() {
}
}
+ // Verify mixed static/dynamic control variables.
+ Operation *op = getOperation();
+ if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
+ getStaticLowerBound(),
+ getDynamicLowerBound())))
+ return failure();
+ if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
+ getStaticUpperBound(),
+ getDynamicUpperBound())))
+ return failure();
+ if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
+ getStaticStep(), getDynamicStep())))
+ return failure();
+
return success();
}
void ForeachThreadOp::print(OpAsmPrinter &p) {
- p << " (";
- llvm::interleaveComma(getThreadIndices(), p);
- p << ") in (";
- llvm::interleaveComma(getNumThreads(), p);
- p << ")";
+ Operation *op = getOperation();
+ p << " (" << getInductionVars();
+ if (isNormalized()) {
+ p << ") in ";
+ printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
+ OpAsmParser::Delimiter::Paren);
+ } else {
+ p << ") = ";
+ printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
+ OpAsmParser::Delimiter::Paren);
+ p << " to ";
+ printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
+ OpAsmParser::Delimiter::Paren);
+ p << " step ";
+ printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
+ OpAsmParser::Delimiter::Paren);
+ }
printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
p << " ";
if (!getRegionOutArgs().empty())
@@ -1154,28 +1181,60 @@ void ForeachThreadOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/getNumResults() > 0);
- p.printOptionalAttrDict(getOperation()->getAttrs(),
- {"operand_segment_sizes"});
+ p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
+ getStaticLowerBoundAttrName(),
+ getStaticUpperBoundAttrName(),
+ getStaticStepAttrName()});
}
ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
OperationState &result) {
- auto &builder = parser.getBuilder();
+ OpBuilder b(parser.getContext());
+ auto indexType = b.getIndexType();
+
// Parse an opening `(` followed by thread index variables followed by `)`
// TODO: when we can refer to such "induction variable"-like handles from the
// declarative assembly format, we can implement the parser as a custom hook.
- SmallVector<OpAsmParser::Argument, 4> threadIndices;
- if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren))
+ SmallVector<OpAsmParser::Argument, 4> ivs;
+ if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
return failure();
- // Parse `in` threadNums.
- SmallVector<OpAsmParser::UnresolvedOperand, 4> threadNums;
- if (parser.parseKeyword("in") ||
- parser.parseOperandList(threadNums, threadIndices.size(),
+ DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
+ SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
+ dynamicSteps;
+ if (succeeded(parser.parseOptionalKeyword("in"))) {
+ // Parse upper bounds.
+ if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
OpAsmParser::Delimiter::Paren) ||
- parser.resolveOperands(threadNums, builder.getIndexType(),
- result.operands))
- return failure();
+ parser.resolveOperands(dynamicUbs, indexType, result.operands))
+ return failure();
+
+ unsigned numLoops = ivs.size();
+ staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
+ staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
+ } else {
+ // Parse lower bounds.
+ if (parser.parseEqual() ||
+ parseDynamicIndexList(parser, dynamicLbs, staticLbs,
+ OpAsmParser::Delimiter::Paren) ||
+ parser.resolveOperands(dynamicLbs, indexType, result.operands))
+ return failure();
+
+ // Parse upper bounds.
+ if (parser.parseKeyword("to") ||
+ parseDynamicIndexList(parser, dynamicUbs, staticUbs,
+ OpAsmParser::Delimiter::Paren) ||
+ parser.resolveOperands(dynamicUbs, indexType, result.operands))
+ return failure();
+
+ // Parse step values.
+ if (parser.parseKeyword("step") ||
+ parseDynamicIndexList(parser, dynamicSteps, staticSteps,
+ OpAsmParser::Delimiter::Paren) ||
+ parser.resolveOperands(dynamicSteps, indexType, result.operands))
+ return failure();
+ }
// Parse out operands and results.
SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
@@ -1195,9 +1254,9 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
// Parse region.
SmallVector<OpAsmParser::Argument, 4> regionArgs;
std::unique_ptr<Region> region = std::make_unique<Region>();
- for (auto &idx : threadIndices) {
- idx.type = builder.getIndexType();
- regionArgs.push_back(idx);
+ for (auto &iv : ivs) {
+ iv.type = b.getIndexType();
+ regionArgs.push_back(iv);
}
for (const auto &it : llvm::enumerate(regionOutArgs)) {
auto &out = it.value();
@@ -1208,92 +1267,111 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
return failure();
// Ensure terminator and move region.
- OpBuilder b(builder.getContext());
ForeachThreadOp::ensureTerminator(*region, b, result.location);
result.addRegion(std::move(region));
// Parse the optional attribute list.
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
+
+ result.addAttribute("staticLowerBound", staticLbs);
+ result.addAttribute("staticUpperBound", staticUbs);
+ result.addAttribute("staticStep", staticSteps);
result.addAttribute("operand_segment_sizes",
parser.getBuilder().getDenseI32ArrayAttr(
- {static_cast<int32_t>(threadNums.size()),
+ {static_cast<int32_t>(dynamicLbs.size()),
+ static_cast<int32_t>(dynamicUbs.size()),
+ static_cast<int32_t>(dynamicSteps.size()),
static_cast<int32_t>(outOperands.size())}));
return success();
}
-// Bodyless builder, outputs must be specified.
-void ForeachThreadOp::build(mlir::OpBuilder &builder,
- mlir::OperationState &result, ValueRange outputs,
- ValueRange numThreads,
- std::optional<ArrayAttr> mapping) {
- result.addOperands(numThreads);
+// Builder that takes loop bounds.
+void ForeachThreadOp::build(
+ mlir::OpBuilder &b, mlir::OperationState &result,
+ ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
+ ArrayRef<OpFoldResult> steps, ValueRange outputs,
+ std::optional<ArrayAttr> mapping,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
+ SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
+ SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
+ dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
+ dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
+ dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
+
+ result.addOperands(dynamicLbs);
+ result.addOperands(dynamicUbs);
+ result.addOperands(dynamicSteps);
result.addOperands(outputs);
+ result.addTypes(TypeRange(outputs));
+
+ result.addAttribute(getStaticLowerBoundAttrName(result.name),
+ b.getDenseI64ArrayAttr(staticLbs));
+ result.addAttribute(getStaticUpperBoundAttrName(result.name),
+ b.getDenseI64ArrayAttr(staticUbs));
+ result.addAttribute(getStaticStepAttrName(result.name),
+ b.getDenseI64ArrayAttr(staticSteps));
+ result.addAttribute(
+ "operand_segment_sizes",
+ b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
+ static_cast<int32_t>(dynamicUbs.size()),
+ static_cast<int32_t>(dynamicSteps.size()),
+ static_cast<int32_t>(outputs.size())}));
if (mapping.has_value()) {
result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name),
mapping.value());
}
- result.addAttribute(
- "operand_segment_sizes",
- builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
- static_cast<int32_t>(outputs.size())}));
- result.addTypes(TypeRange(outputs));
-
Region *bodyRegion = result.addRegion();
- OpBuilder::InsertionGuard g(builder);
- // createBlock sets the IP inside the block.
- // Generally we would guard against that but the default ensureTerminator impl
- // expects it ..
- builder.createBlock(bodyRegion);
+ OpBuilder::InsertionGuard g(b);
+ b.createBlock(bodyRegion);
Block &bodyBlock = bodyRegion->front();
- // Add block arguments for indices and outputs.
- bodyBlock.addArguments(
- SmallVector<Type>(numThreads.size(), builder.getIndexType()),
- SmallVector<Location>(numThreads.size(), result.location));
- bodyBlock.addArguments(
- TypeRange(outputs),
- SmallVector<Location>(outputs.size(), result.location));
- ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
-}
-// Builder that takes a bodyBuilder lambda.
-void ForeachThreadOp::build(
- mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs,
- ValueRange numThreads, ArrayRef<Attribute> mapping,
- function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
- result.addOperands(numThreads);
- result.addOperands(outputs);
- result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name),
- builder.getArrayAttr(mapping));
- result.addAttribute(
- "operand_segment_sizes",
- builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
- static_cast<int32_t>(outputs.size())}));
- result.addTypes(TypeRange(outputs));
-
- Region *bodyRegion = result.addRegion();
- OpBuilder::InsertionGuard g(builder);
- builder.createBlock(bodyRegion);
- Block &bodyBlock = bodyRegion->front();
// Add block arguments for indices and outputs.
bodyBlock.addArguments(
- SmallVector<Type>(numThreads.size(), builder.getIndexType()),
- SmallVector<Location>(numThreads.size(), result.location));
+ SmallVector<Type>(lbs.size(), b.getIndexType()),
+ SmallVector<Location>(staticLbs.size(), result.location));
bodyBlock.addArguments(
TypeRange(outputs),
SmallVector<Location>(outputs.size(), result.location));
- builder.setInsertionPointToStart(&bodyBlock);
- bodyBuilder(builder, result.location, bodyBlock.getArguments());
+ b.setInsertionPointToStart(&bodyBlock);
+ if (!bodyBuilderFn) {
+ ForeachThreadOp::ensureTerminator(*bodyRegion, b, result.location);
+ return;
+ }
+ bodyBuilderFn(b, result.location, bodyBlock.getArguments());
#ifndef NDEBUG
auto terminator =
llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
assert(terminator &&
- "expected bodyBuilder to create PerformConcurrentlyOp terminator");
+ "expected bodyBuilderFn to create PerformConcurrentlyOp terminator");
#endif // NDEBUG
}
+// Builder for a normalized loop that takes only upper bounds.
+void ForeachThreadOp::build(
+ mlir::OpBuilder &b, mlir::OperationState &result,
+ ArrayRef<OpFoldResult> ubs, ValueRange outputs,
+ std::optional<ArrayAttr> mapping,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
+ unsigned numLoops = ubs.size();
+ SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
+ SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
+ build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
+}
+
+// Returns true if all lower bounds are zero and all steps are one.
+bool ForeachThreadOp::isNormalized() {
+ auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
+ return llvm::all_of(results, [&](OpFoldResult ofr) {
+ auto intValue = getConstantIntValue(ofr);
+ return intValue.has_value() && intValue == val;
+ });
+ };
+ return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
+}
+
// The ensureTerminator method generated by SingleBlockImplicitTerminator is
// unaware of the fact that our terminator also needs a region to be
// well-formed. We override it here to ensure that we do the right thing.
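A note on round-tripping: `isNormalized()` inspects only the mixed bounds, so a loop written with explicit static zeros and ones prints back in the short form. A minimal sketch (hypothetical IR):

```mlir
func.func @roundtrip(%ub0: index, %ub1: index) {
  // All lower bounds are the static 0 and all steps the static 1, so
  // isNormalized() returns true...
  scf.foreach_thread (%i, %j) = (0, 0) to (%ub0, %ub1) step (1, 1) {
    scf.foreach_thread.perform_concurrently {
    }
  }
  // ...and the op prints back in the elided form:
  //   scf.foreach_thread (%i, %j) in (%ub0, %ub1) { ... }
  return
}
```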
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index e844f462389a9..d5c227967b36a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1004,15 +1004,14 @@ struct YieldOpInterface
/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForeachThreadOp foreachThreadOp) {
- int64_t p = 1;
- for (Value v : foreachThreadOp.getNumThreads()) {
- if (std::optional<int64_t> c = getConstantIntValue(v)) {
- p *= *c;
- } else {
+ for (auto [lb, ub] : llvm::zip(foreachThreadOp.getMixedLowerBound(),
+ foreachThreadOp.getMixedUpperBound())) {
+ std::optional<int64_t> lbConst = getConstantIntValue(lb);
+ std::optional<int64_t> ubConst = getConstantIntValue(ub);
+ if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
return true;
- }
}
- return p == 0;
+ return false;
}
/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
@@ -1087,8 +1086,9 @@ struct ForeachThreadOpInterface
rewriter.setInsertionPoint(foreachThreadOp);
ForeachThreadOp newForeachThreadOp;
newForeachThreadOp = rewriter.create<ForeachThreadOp>(
- foreachThreadOp.getLoc(), /*outputs=*/ValueRange(),
- foreachThreadOp.getNumThreads(), foreachThreadOp.getMapping());
+ foreachThreadOp.getLoc(), foreachThreadOp.getMixedLowerBound(),
+ foreachThreadOp.getMixedUpperBound(), foreachThreadOp.getMixedStep(),
+ /*outputs=*/ValueRange(), foreachThreadOp.getMapping());
newForeachThreadOp.getBody()->getTerminator()->erase();
@@ -1127,10 +1127,28 @@ struct ForeachThreadOpInterface
bool isRepetitiveRegion(Operation *op, unsigned index) const {
auto foreachThreadOp = cast<ForeachThreadOp>(op);
- // This op is not repetitive if it has just a single thread.
- return !llvm::all_of(foreachThreadOp.getNumThreads(), [](Value v) {
- return getConstantIntValue(v) == static_cast<int64_t>(1);
- });
+
+ // This op is repetitive if any loop dimension performs more than one
+ // iteration. If any of the control operands is dynamic, conservatively
+ // treat the op as repetitive.
+ for (auto [lb, ub, step] : llvm::zip(foreachThreadOp.getMixedLowerBound(),
+ foreachThreadOp.getMixedUpperBound(),
+ foreachThreadOp.getMixedStep())) {
+ std::optional<int64_t> lbConstant = getConstantIntValue(lb);
+ if (!lbConstant)
+ return true;
+
+ std::optional<int64_t> ubConstant = getConstantIntValue(ub);
+ if (!ubConstant)
+ return true;
+
+ std::optional<int64_t> stepConstant = getConstantIntValue(step);
+ if (!stepConstant)
+ return true;
+
+ if (*lbConstant + *stepConstant < *ubConstant)
+ return true;
+ }
+ return false;
}
};
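To make the constant-bound arithmetic above concrete (hypothetical IR): with lb = 0, ub = 10, step = 5 the loop runs iterations 0 and 5, so `*lbConstant + *stepConstant < *ubConstant` holds and the region is repetitive; with lb = ub the constant range is empty and `mayHaveZeroIterations` returns true:

```mlir
func.func @repetitiveness_examples() {
  // Repetitive: 0 + 5 < 10, i.e. the loop performs more than one iteration.
  scf.foreach_thread (%i) = (0) to (10) step (5) {
    scf.foreach_thread.perform_concurrently {
    }
  }
  // May have zero iterations: the constant range [4, 4) is empty, so
  // mayHaveZeroIterations returns true.
  scf.foreach_thread (%j) = (4) to (4) step (1) {
    scf.foreach_thread.perform_concurrently {
    }
  }
  return
}
```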
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index e0f236b633f82..79a688ad9d4f7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -180,10 +180,10 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
if (scf::ForeachThreadOp foreachThreadOp =
scf::getForeachThreadOpThreadIndexOwner(iv)) {
for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) {
- if (foreachThreadOp.getThreadIndices()[idx] == iv) {
- lb = OpBuilder(iv.getContext()).getIndexAttr(0);
- ub = foreachThreadOp.getNumThreads()[idx];
- step = OpBuilder(iv.getContext()).getIndexAttr(1);
+ if (foreachThreadOp.getInductionVar(idx) == iv) {
+ lb = foreachThreadOp.getMixedLowerBound()[idx];
+ ub = foreachThreadOp.getMixedUpperBound()[idx];
+ step = foreachThreadOp.getMixedStep()[idx];
return success();
}
}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 9d30f2797c0e8..0b1ecc9115fd7 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -69,12 +69,45 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
return success();
}
+static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
+ switch (delimiter) {
+ case AsmParser::Delimiter::Paren:
+ return '(';
+ case AsmParser::Delimiter::LessGreater:
+ return '<';
+ case AsmParser::Delimiter::Square:
+ return '[';
+ case AsmParser::Delimiter::Braces:
+ return '{';
+ default:
+ llvm_unreachable("unsupported delimiter");
+ }
+}
+
+static char getRightDelimiter(AsmParser::Delimiter delimiter) {
+ switch (delimiter) {
+ case AsmParser::Delimiter::Paren:
+ return ')';
+ case AsmParser::Delimiter::LessGreater:
+ return '>';
+ case AsmParser::Delimiter::Square:
+ return ']';
+ case AsmParser::Delimiter::Braces:
+ return '}';
+ default:
+ llvm_unreachable("unsupported delimiter");
+ }
+}
+
void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
- ArrayRef<int64_t> integers) {
- printer << '[';
+ ArrayRef<int64_t> integers,
+ AsmParser::Delimiter delimiter) {
+ char leftDelimiter = getLeftDelimiter(delimiter);
+ char rightDelimiter = getRightDelimiter(delimiter);
+ printer << leftDelimiter;
if (integers.empty()) {
- printer << "]";
+ printer << rightDelimiter;
return;
}
unsigned idx = 0;
@@ -84,13 +117,13 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
else
printer << integer;
});
- printer << ']';
+ printer << rightDelimiter;
}
ParseResult mlir::parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- DenseI64ArrayAttr &integers) {
+ DenseI64ArrayAttr &integers, AsmParser::Delimiter delimiter) {
SmallVector<int64_t, 4> integerVals;
auto parseIntegerOrValue = [&]() {
@@ -107,8 +140,7 @@ ParseResult mlir::parseDynamicIndexList(
}
return success();
};
- if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square,
- parseIntegerOrValue,
+ if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
" in dynamic index list"))
return parser.emitError(parser.getNameLoc())
<< "expected SSA value or integer";
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
index 586d09c51ebac..7b04d6c5e152b 100644
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -13,9 +13,7 @@ module {
// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<?x?xf32>
func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
- // CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
- // CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
- // CHECK: scf.foreach_thread ({{.*}}) in (%[[C10]], %[[C20]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) -> (tensor<?x?xf32>) {
+ // CHECK: scf.foreach_thread ({{.*}}) in (10, 20) shared_outs(%[[C_BLK:.*]] = %[[C]]) -> (tensor<?x?xf32>) {
// CHECK: %[[tA:.*]] = tensor.extract_slice %[[A]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[tB:.*]] = tensor.extract_slice %[[B]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[tC:.*]] = tensor.extract_slice %[[C_BLK]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
@@ -95,9 +93,7 @@ transform.sequence failures(propagate) {
// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor
// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor
func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
- // CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
- // CHECK-DAG: %[[c21:.+]] = arith.constant 21 : index
- // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c21]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+ // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (10, 21) shared_outs(%[[C_BLK:.*]] = %[[C]])
// CHECK: %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV1]])
// CHECK: %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
// CHECK-NOT: affine.min
@@ -175,9 +171,7 @@ transform.sequence failures(propagate) {
// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor
// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor
func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
- // CHECK-DAG: %[[c10:.+]] = arith.constant 10 :
- // CHECK-DAG: %[[c15:.+]] = arith.constant 15 :
- // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+ // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (10, 15) shared_outs(%[[C_BLK:.*]] = %[[C]])
// CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])
// CHECK-NOT: affine.max
// CHECK-NOT: affine.min
@@ -225,8 +219,7 @@ module {
// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-LABEL: extract_source(
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) shared_outs(%{{.*}} = %{{.*}}) -> (tensor<4xf32>) {
+// CHECK: scf.foreach_thread (%[[ARG:.*]]) in (2) shared_outs(%{{.*}} = %{{.*}}) -> (tensor<4xf32>) {
// CHECK: %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]])
// CHECK: scf.foreach_thread.perform_concurrently {
// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32>
@@ -289,8 +282,7 @@ transform.sequence failures(propagate) {
func.func @tile_output_multi_1d_static(%IN1: tensor<100xf32>, %IN2: tensor<100xf32>,
%OUT1: tensor<100xf32>, %OUT2: tensor<100xf32>)
-> (tensor<100xf32>, tensor<100xf32>) {
-// CHECK-DAG: %[[c0:.+]] = arith.constant 7 :
-// CHECK: scf.foreach_thread (%[[IV0:.+]]) in (%[[c0]]) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
+// CHECK: scf.foreach_thread (%[[IV0:.+]]) in (7) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
// CHECK: %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV0]])
// CHECK: %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
// CHECK-NOT: affine.min
@@ -345,8 +337,7 @@ transform.sequence failures(propagate) {
func.func @tile_output_multi_1d2d_static(%IN1: tensor<100xf32>, %IN2: tensor<100x300xf32>, %IN3: tensor<300xf32>,
%OUT1: tensor<300x100xf32>, %OUT2: tensor<300xf32>)
-> (tensor<300x100xf32>, tensor<300xf32>) {
-// CHECK-DAG: %[[c0:.+]] = arith.constant 4 :
-// CHECK: scf.foreach_thread (%[[IV0:.+]]) in (%[[c0]]) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
+// CHECK: scf.foreach_thread (%[[IV0:.+]]) in (4) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
// CHECK: %[[LB:.+]] = affine.apply #[[$map0]](%[[IV0]])
// CHECK: %[[tIN1:.+]] = tensor.extract_slice %[[IN2]][0, %[[LB]]] [100, 75]
// CHECK: %[[tIN2:.+]] = tensor.extract_slice %[[IN3]][%[[LB]]] [75]
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 4525527bd7777..0c4d9f054ae39 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -122,13 +122,12 @@ transform.sequence failures(propagate) {
// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
-// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
+// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
// CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
@@ -175,7 +174,6 @@ transform.sequence failures(propagate) {
// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
@@ -183,7 +181,7 @@ transform.sequence failures(propagate) {
// CHECK-DAG: %[[D4:.*]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
-// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
+// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
// CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
@@ -235,13 +233,12 @@ transform.sequence failures(propagate) {
// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
-// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
+// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[LB:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 18413e8e2a2fb..6d79f2cd01740 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -311,8 +311,8 @@ func.func @execute_region() -> i64 {
return %res : i64
}
-// CHECK-LABEL: func.func @simple_example
-func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
+// CHECK-LABEL: func.func @normalized_foreach_thread
+func.func @normalized_foreach_thread(%in: tensor<100xf32>, %out: tensor<100xf32>) {
%c1 = arith.constant 1 : index
%num_threads = arith.constant 100 : index
@@ -333,8 +333,32 @@ func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
return
}
-// CHECK-LABEL: func.func @elide_terminator
-func.func @elide_terminator() -> () {
+// CHECK-LABEL: func.func @explicit_loop_bounds_foreach_thread
+func.func @explicit_loop_bounds_foreach_thread(%in: tensor<100xf32>,
+ %out: tensor<100xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %num_threads = arith.constant 100 : index
+
+ // CHECK: scf.foreach_thread
+ // CHECK-NEXT: tensor.extract_slice
+ // CHECK-NEXT: scf.foreach_thread.perform_concurrently
+ // CHECK-NEXT: tensor.parallel_insert_slice
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ %result = scf.foreach_thread (%thread_idx) = (%c0) to (%num_threads) step (%c1) shared_outs(%o = %out) -> tensor<100xf32> {
+ %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
+ scf.foreach_thread.perform_concurrently {
+ tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
+ tensor<1xf32> into tensor<100xf32>
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @normalized_foreach_thread_elide_terminator
+func.func @normalized_foreach_thread_elide_terminator() -> () {
%num_threads = arith.constant 100 : index
// CHECK: scf.foreach_thread
@@ -345,6 +369,23 @@ func.func @elide_terminator() -> () {
}
} {mapping = [#gpu.thread<x>]}
return
+
+}
+
+// CHECK-LABEL: func.func @explicit_loop_bounds_foreach_thread_elide_terminator
+func.func @explicit_loop_bounds_foreach_thread_elide_terminator() -> () {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %num_threads = arith.constant 100 : index
+
+ // CHECK: scf.foreach_thread
+ // CHECK-NEXT: } {mapping = [#gpu.thread<x>]}
+ // CHECK-NEXT: return
+ scf.foreach_thread (%thread_idx) = (%c0) to (%num_threads) step (%c1) {
+ scf.foreach_thread.perform_concurrently {
+ }
+ } {mapping = [#gpu.thread<x>]}
+ return
}
// CHECK-LABEL: @switch
diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
index 541a54a911289..f366d331dff58 100644
--- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
+++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
@@ -24,12 +24,11 @@ func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf3
// CHECK: return %[[tile]]
// FOREACH: func.func @extract_slice_static(%[[arg0:.+]]:
-// FOREACH-DAG: %[[c20:.+]] = arith.constant 20 : index
// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index
// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index
// FOREACH-DAG: %[[init:.+]] = tensor.empty() : tensor<20x11xf32>
-// FOREACH: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (%[[c20]]) shared_outs(%[[dest:.+]] = %[[init]])
+// FOREACH: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (20) shared_outs(%[[dest:.+]] = %[[init]])
// FOREACH: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 25039a5c41f92..6dcf6cf97339d 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -247,9 +247,10 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
tensor::ExtractSliceFromCollapseHelper &helper,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto foreachOp = rewriter.create<scf::ForeachThreadOp>(
- loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(),
- /*mapping=*/ArrayRef<Attribute>{},
+ auto foreachThreadOp = rewriter.create<scf::ForeachThreadOp>(
+ loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()),
+ /*outputs=*/dest,
+ /*mapping=*/std::nullopt,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
unsigned numThreadIdRegionArgs =
helper.getIterationSpaceSizes().size();
@@ -267,7 +268,7 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
nestedBuilder.create<tensor::ParallelInsertSliceOp>(
loc, tile, outputArgs[0], insertParams);
});
- rewriter.replaceOp(op, foreachOp->getResult(0));
+ rewriter.replaceOp(op, foreachThreadOp->getResult(0));
return success();
}
};