[Mlir-commits] [mlir] [mlir][SCF] `scf.parallel`: Make reductions part of the terminator (PR #75314)
Matthias Springer
llvmlistbot at llvm.org
Tue Dec 19 17:57:00 PST 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/75314
From 4566abde04e98e693650de1b2bc2955b64ad45e8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 20 Dec 2023 10:47:46 +0900
Subject: [PATCH] [mlir][SCF] `scf.parallel`: Make reductions part of the
terminator
This commit makes reductions part of the `scf.parallel` terminator. Instead of `scf.yield`, `scf.reduce` now terminates the body of `scf.parallel` ops. A single `scf.reduce` may hold an arbitrary number of reductions, with one region per reduction.
`scf.reduce` operations can no longer be interleaved with other ops in the body of `scf.parallel`. This simplifies the op and makes it possible to assign the `RecursiveMemoryEffects` trait to `scf.reduce`. (This was not possible before: as a non-terminator, a side-effect-free `scf.reduce` would have been considered trivially dead and DCE'd.)
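For illustration, a minimal before/after sketch (adapted from the updated `canonicalize.mlir` test in this PR; `%lb`, `%ub`, `%step`, `%init1`, and `%init2` stand in for values defined elsewhere):
```mlir
// Before this patch: one scf.reduce op per reduction, interleaved with the
// body, and the block terminated by an empty scf.yield.
%0:2 = scf.parallel (%i) = (%lb) to (%ub) step (%step)
    init(%init1, %init2) -> (index, index) {
  scf.reduce(%i) : index {
  ^bb0(%lhs: index, %rhs: index):
    %1 = arith.addi %lhs, %rhs : index
    scf.reduce.return %1 : index
  }
  scf.reduce(%i) : index {
  ^bb0(%lhs: index, %rhs: index):
    %2 = arith.muli %lhs, %rhs : index
    scf.reduce.return %2 : index
  }
  scf.yield
}

// After this patch: a single scf.reduce terminator carries all reduced
// values, with one region per reduction.
%0:2 = scf.parallel (%i) = (%lb) to (%ub) step (%step)
    init(%init1, %init2) -> (index, index) {
  scf.reduce(%i, %i : index, index) {
  ^bb0(%lhs: index, %rhs: index):
    %1 = arith.addi %lhs, %rhs : index
    scf.reduce.return %1 : index
  }, {
  ^bb0(%lhs: index, %rhs: index):
    %2 = arith.muli %lhs, %rhs : index
    scf.reduce.return %2 : index
  }
}
```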
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 111 +++++++------
.../AffineToStandard/AffineToStandard.cpp | 27 +--
.../SCFToControlFlow/SCFToControlFlow.cpp | 24 ++-
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 115 +++++++------
.../Async/Transforms/AsyncParallelFor.cpp | 3 +-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 155 ++++++++----------
.../SCF/Transforms/ParallelLoopTiling.cpp | 5 +
.../Transforms/SparseGPUCodegen.cpp | 3 +
.../Transforms/Utils/LoopEmitter.cpp | 2 +-
.../AffineToStandard/lower-affine.mlir | 24 +--
.../SCFToControlFlow/convert-to-cfg.mlir | 13 +-
.../Conversion/SCFToGPU/parallel_loop.mlir | 8 +-
.../Conversion/SCFToOpenMP/reductions.mlir | 19 +--
.../Conversion/SCFToSPIRV/unsupported.mlir | 8 +-
mlir/test/Dialect/Linalg/parallel-loops.mlir | 2 +-
.../Dialect/Linalg/transform-op-match.mlir | 2 +-
.../test/Dialect/SCF/buffer-deallocation.mlir | 2 +-
mlir/test/Dialect/SCF/canonicalize.mlir | 23 ++-
mlir/test/Dialect/SCF/invalid.mlir | 35 ++--
mlir/test/Dialect/SCF/ops.mlir | 22 ++-
.../Dialect/SCF/parallel-loop-fusion.mlir | 66 ++++----
.../SparseTensor/sparse_parallel_reduce.mlir | 5 +-
.../invalid-parallel-loop-collapsing.mlir | 4 +-
.../loop-invariant-code-motion.mlir | 2 +-
.../Transforms/parallel-loop-collapsing.mlir | 2 +-
.../single-parallel-loop-collapsing.mlir | 2 +-
26 files changed, 344 insertions(+), 340 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 573e804b405e84..8d65d3dd820baf 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -770,7 +770,7 @@ def ParallelOp : SCF_Op<"parallel",
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
RecursiveMemoryEffects,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
- SingleBlockImplicitTerminator<"scf::YieldOp">]> {
+ SingleBlockImplicitTerminator<"scf::ReduceOp">]> {
let summary = "parallel for operation";
let description = [{
The "scf.parallel" operation represents a loop nest taking 4 groups of SSA
@@ -791,27 +791,36 @@ def ParallelOp : SCF_Op<"parallel",
The parallel loop operation supports reduction of values produced by
individual iterations into a single result. This is modeled using the
- scf.reduce operation (see scf.reduce for details). Each result of a
- scf.parallel operation is associated with an initial value operand and
- reduce operation that is an immediate child. Reductions are matched to
- result and initial values in order of their appearance in the body.
- Consequently, we require that the body region has the same number of
- results and initial values as it has reduce operations.
-
- The body region must contain exactly one block that terminates with
- "scf.yield" without operands. Parsing ParallelOp will create such a region
- and insert the terminator when it is absent from the custom format.
+ "scf.reduce" terminator operation (see "scf.reduce" for details). The i-th
+ result of an "scf.parallel" operation is associated with the i-th initial
+ value operand, the i-th operand of the "scf.reduce" operation (the value to
+ be reduced) and the i-th region of the "scf.reduce" operation (the reduction
+ function). Consequently, we require that the number of results of an
+ "scf.parallel" op matches the number of initial values and the the number of
+ reductions in the "scf.reduce" terminator.
+
+ The body region must contain exactly one block that terminates with a
+ "scf.reduce" operation. If an "scf.parallel" op has no reductions, the
+ terminator has no operands and no regions. For ops without reductions, the
+ "scf.parallel" parser automatically inserts the terminator if it is absent
+ from the custom assembly format.
Example:
```mlir
%init = arith.constant 0.0 : f32
- scf.parallel (%iv) = (%lb) to (%ub) step (%step) init (%init) -> f32 {
- %elem_to_reduce = load %buffer[%iv] : memref<100xf32>
- scf.reduce(%elem_to_reduce) : f32 {
+ %r:2 = scf.parallel (%iv) = (%lb) to (%ub) step (%step) init (%init, %init)
+ -> f32, f32 {
+ %elem_to_reduce1 = load %buffer1[%iv] : memref<100xf32>
+ %elem_to_reduce2 = load %buffer2[%iv] : memref<100xf32>
+ scf.reduce(%elem_to_reduce1, %elem_to_reduce2 : f32, f32) {
^bb0(%lhs : f32, %rhs: f32):
%res = arith.addf %lhs, %rhs : f32
scf.reduce.return %res : f32
+ }, {
+ ^bb0(%lhs : f32, %rhs: f32):
+ %res = arith.mulf %lhs, %rhs : f32
+ scf.reduce.return %res : f32
}
}
```
@@ -853,36 +862,36 @@ def ParallelOp : SCF_Op<"parallel",
// ReduceOp
//===----------------------------------------------------------------------===//
-def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {
- let summary = "reduce operation for parallel for";
+def ReduceOp : SCF_Op<"reduce", [
+ Terminator, HasParent<"ParallelOp">, RecursiveMemoryEffects,
+ DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>]> {
+ let summary = "reduce operation for scf.parallel";
let description = [{
- "scf.reduce" is an operation occurring inside "scf.parallel" operations.
- It consists of one block with two arguments which have the same type as the
- operand of "scf.reduce".
-
- "scf.reduce" is used to model the value for reduction computations of a
- "scf.parallel" operation. It has to appear as an immediate child of a
- "scf.parallel" and is associated with a result value of its parent
- operation.
-
- Association is in the order of appearance in the body where the first
- result of a parallel loop operation corresponds to the first "scf.reduce"
- in the operation's body region. The reduce operation takes a single
- operand, which is the value to be used in the reduction.
-
- The reduce operation contains a region whose entry block expects two
- arguments of the same type as the operand. As the iteration order of the
- parallel loop and hence reduction order is unspecified, the result of
- reduction may be non-deterministic unless the operation is associative and
- commutative.
-
- The result of the reduce operation's body must have the same type as the
- operands and associated result value of the parallel loop operation.
+ "scf.reduce" is the terminator for "scf.parallel" operations. It can model
+ an arbitrary number of reductions. It has one region per reduction. Each
+ region has one block with two arguments which have the same type as the
+ corresponding operand of "scf.reduce". The operands of the op are the values
+ that should be reduced; one value per reduction.
+
+ The i-th reduction (i.e., the i-th region and the i-th operand) corresponds
+ to the i-th initial value and the i-th result of the enclosing "scf.parallel"
+ op.
+
+ The "scf.reduce" operation contains regions whose entry blocks expect two
+ arguments of the same type as the corresponding operand. As the iteration
+ order of the enclosing parallel loop, and hence the reduction order, is
+ unspecified, the results of the reductions may be non-deterministic unless
+ the reductions are associative and commutative.
+
+ The result of a reduction region ("scf.reduce.return" operand) must have the
+ same type as the corresponding "scf.reduce" operand and the corresponding
+ "scf.parallel" initial value.
+
Example:
```mlir
%operand = arith.constant 1.0 : f32
- scf.reduce(%operand) : f32 {
+ scf.reduce(%operand : f32) {
^bb0(%lhs : f32, %rhs: f32):
%res = arith.addf %lhs, %rhs : f32
scf.reduce.return %res : f32
@@ -892,14 +901,15 @@ def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {
let skipDefaultBuilders = 1;
let builders = [
- OpBuilder<(ins "Value":$operand,
- CArg<"function_ref<void (OpBuilder &, Location, Value, Value)>",
- "nullptr">:$bodyBuilderFn)>
+ OpBuilder<(ins "ValueRange":$operands)>,
+ OpBuilder<(ins)>
];
- let arguments = (ins AnyType:$operand);
- let hasCustomAssemblyFormat = 1;
- let regions = (region SizedRegion<1>:$reductionOperator);
+ let arguments = (ins Variadic<AnyType>:$operands);
+ let assemblyFormat = [{
+ (`(` $operands^ `:` type($operands) `)`)? $reductions attr-dict
+ }];
+ let regions = (region VariadicRegion<SizedRegion<1>>:$reductions);
let hasRegionVerifier = 1;
}
@@ -908,13 +918,14 @@ def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {
//===----------------------------------------------------------------------===//
def ReduceReturnOp :
- SCF_Op<"reduce.return", [HasParent<"ReduceOp">, Pure,
- Terminator]> {
+ SCF_Op<"reduce.return", [HasParent<"ReduceOp">, Pure, Terminator]> {
let summary = "terminator for reduce operation";
let description = [{
"scf.reduce.return" is a special terminator operation for the block inside
- "scf.reduce". It terminates the region. It should have the same type as
- the operand of "scf.reduce". Example for the custom format:
+ "scf.reduce" regions. It terminates the region. It should have the same
+ operand type as the corresponding operand of the enclosing "scf.reduce" op.
+
+ Example:
```mlir
scf.reduce.return %res : f32
@@ -1150,7 +1161,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
- "ParallelOp", "WhileOp"]>]> {
+ "WhileOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
"scf.yield" yields an SSA value from the SCF dialect op region and
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 7dbbf015182f39..15ad6d8cdf629d 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -137,10 +137,9 @@ class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
LogicalResult matchAndRewrite(AffineYieldOp op,
PatternRewriter &rewriter) const override {
if (isa<scf::ParallelOp>(op->getParentOp())) {
- // scf.parallel does not yield any values via its terminator scf.yield but
- // models reductions differently using additional ops in its region.
- rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
- return success();
+ // Terminator is rewritten as part of the "affine.parallel" lowering
+ // pattern.
+ return failure();
}
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.getOperands());
return success();
@@ -203,7 +202,8 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
// Get the terminator op.
- Operation *affineParOpTerminator = op.getBody()->getTerminator();
+ auto affineParOpTerminator =
+ cast<AffineYieldOp>(op.getBody()->getTerminator());
scf::ParallelOp parOp;
if (op.getResults().empty()) {
// Case with no reduction operations/return values.
@@ -214,6 +214,8 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
parOp.getRegion().end());
rewriter.replaceOp(op, parOp.getResults());
+ rewriter.setInsertionPoint(affineParOpTerminator);
+ rewriter.replaceOpWithNewOp<scf::ReduceOp>(affineParOpTerminator);
return success();
}
// Case with affine.parallel with reduction operations/return values.
@@ -243,6 +245,11 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
parOp.getRegion().end());
assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
"Unequal number of reductions and operands.");
+
+ // Emit new "scf.reduce" terminator.
+ rewriter.setInsertionPoint(affineParOpTerminator);
+ auto reduceOp = rewriter.replaceOpWithNewOp<scf::ReduceOp>(
+ affineParOpTerminator, affineParOpTerminator->getOperands());
for (unsigned i = 0, end = reductions.size(); i < end; i++) {
// For each of the reduction operations get the respective mlir::Value.
std::optional<arith::AtomicRMWKind> reductionOp =
@@ -251,13 +258,11 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
assert(reductionOp && "Reduction Operation cannot be of None Type");
arith::AtomicRMWKind reductionOpValue = *reductionOp;
rewriter.setInsertionPoint(&parOp.getBody()->back());
- auto reduceOp = rewriter.create<scf::ReduceOp>(
- loc, affineParOpTerminator->getOperand(i));
- rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front());
+ Block &reductionBody = reduceOp.getReductions()[i].front();
+ rewriter.setInsertionPointToEnd(&reductionBody);
Value reductionResult = arith::getReductionOp(
- reductionOpValue, rewriter, loc,
- reduceOp.getReductionOperator().front().getArgument(0),
- reduceOp.getReductionOperator().front().getArgument(1));
+ reductionOpValue, rewriter, loc, reductionBody.getArgument(0),
+ reductionBody.getArgument(1));
rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
}
rewriter.replaceOp(op, parOp.getResults());
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index c9b45fd4a7957b..9eb8a289d7d658 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -471,6 +471,7 @@ LogicalResult
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
PatternRewriter &rewriter) const {
Location loc = parallelOp.getLoc();
+ auto reductionOp = cast<ReduceOp>(parallelOp.getBody()->getTerminator());
// For a parallel loop, we essentially need to create an n-dimensional loop
// nest. We do this by translating to scf.for ops and have those lowered in
@@ -506,23 +507,20 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
}
// First, merge reduction blocks into the main region.
- SmallVector<Value, 4> yieldOperands;
+ SmallVector<Value> yieldOperands;
yieldOperands.reserve(parallelOp.getNumResults());
- for (auto &op : *parallelOp.getBody()) {
- auto reduce = dyn_cast<ReduceOp>(op);
- if (!reduce)
- continue;
-
- Block &reduceBlock = reduce.getReductionOperator().front();
+ for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
+ Block &reductionBody = reductionOp.getReductions()[i].front();
Value arg = iterArgs[yieldOperands.size()];
- yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0));
- rewriter.eraseOp(reduceBlock.getTerminator());
- rewriter.inlineBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()});
- rewriter.eraseOp(reduce);
+ yieldOperands.push_back(
+ cast<ReduceReturnOp>(reductionBody.getTerminator()).getResult());
+ rewriter.eraseOp(reductionBody.getTerminator());
+ rewriter.inlineBlockBefore(&reductionBody, reductionOp,
+ {arg, reductionOp.getOperands()[i]});
}
+ rewriter.eraseOp(reductionOp);
// Then merge the loop body without the terminator.
- rewriter.eraseOp(parallelOp.getBody()->getTerminator());
Block *newBody = rewriter.getInsertionBlock();
if (newBody->empty())
rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
@@ -711,7 +709,7 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
parallelOp.getRegion().begin());
// Replace the terminator.
rewriter.setInsertionPointToEnd(¶llelOp.getRegion().front());
- rewriter.replaceOpWithNewOp<scf::YieldOp>(
+ rewriter.replaceOpWithNewOp<scf::ReduceOp>(
parallelOp.getRegion().front().getTerminator());
// Erase the scf.forall op.
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 67033ba812946f..2f8b3f7e11de15 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -181,32 +181,34 @@ static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
/// Creates an OpenMP reduction declaration and inserts it into the provided
/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the reduction combiner carried over from `reduce`.
-static omp::ReductionDeclareOp createDecl(PatternRewriter &builder,
- SymbolTable &symbolTable,
- scf::ReduceOp reduce,
- Attribute initValue) {
+/// value `initValue`, and the `reductionIndex`-th reduction combiner carried
+/// over from `reduce`.
+static omp::ReductionDeclareOp
+createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
+ scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
OpBuilder::InsertionGuard guard(builder);
- auto decl = builder.create<omp::ReductionDeclareOp>(
- reduce.getLoc(), "__scf_reduction", reduce.getOperand().getType());
+ Type type = reduce.getOperands()[reductionIndex].getType();
+ auto decl = builder.create<omp::ReductionDeclareOp>(reduce.getLoc(),
+ "__scf_reduction", type);
symbolTable.insert(decl);
- Type type = reduce.getOperand().getType();
builder.createBlock(&decl.getInitializerRegion(),
decl.getInitializerRegion().end(), {type},
- {reduce.getOperand().getLoc()});
+ {reduce.getOperands()[reductionIndex].getLoc()});
builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
Value init =
builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
builder.create<omp::YieldOp>(reduce.getLoc(), init);
- Operation *terminator = &reduce.getRegion().front().back();
+ Operation *terminator =
+ &reduce.getReductions()[reductionIndex].front().back();
assert(isa<scf::ReduceReturnOp>(terminator) &&
"expected reduce op to be terminated by redure return");
builder.setInsertionPoint(terminator);
builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
terminator->getOperands());
- builder.inlineRegionBefore(reduce.getRegion(), decl.getReductionRegion(),
+ builder.inlineRegionBefore(reduce.getReductions()[reductionIndex],
+ decl.getReductionRegion(),
decl.getReductionRegion().end());
return decl;
}
@@ -216,10 +218,11 @@ static omp::ReductionDeclareOp createDecl(PatternRewriter &builder,
static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
LLVM::AtomicBinOp atomicKind,
omp::ReductionDeclareOp decl,
- scf::ReduceOp reduce) {
+ scf::ReduceOp reduce,
+ int64_t reductionIndex) {
OpBuilder::InsertionGuard guard(builder);
auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
- Location reduceOperandLoc = reduce.getOperand().getLoc();
+ Location reduceOperandLoc = reduce.getOperands()[reductionIndex].getLoc();
builder.createBlock(&decl.getAtomicReductionRegion(),
decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
{reduceOperandLoc, reduceOperandLoc});
@@ -239,7 +242,8 @@ static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
/// the neutral value, necessary for the OpenMP declaration. If the reduction
/// cannot be recognized, returns null.
static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
- scf::ReduceOp reduce) {
+ scf::ReduceOp reduce,
+ int64_t reductionIndex) {
Operation *container = SymbolTable::getNearestSymbolTable(reduce);
SymbolTable symbolTable(container);
@@ -251,49 +255,58 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(insertionPoint);
- assert(llvm::hasSingleElement(reduce.getRegion()) &&
+ assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&
"expected reduction region to have a single element");
// Match simple binary reductions that can be expressed with atomicrmw.
- Type type = reduce.getOperand().getType();
- Block &reduction = reduce.getRegion().front();
+ Type type = reduce.getOperands()[reductionIndex].getType();
+ Block &reduction = reduce.getReductions()[reductionIndex].front();
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
- omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
- builder.getFloatAttr(type, 0.0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce);
+ omp::ReductionDeclareOp decl =
+ createDecl(builder, symbolTable, reduce, reductionIndex,
+ builder.getFloatAttr(type, 0.0));
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
+ reductionIndex);
}
if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
- omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce);
+ omp::ReductionDeclareOp decl =
+ createDecl(builder, symbolTable, reduce, reductionIndex,
+ builder.getIntegerAttr(type, 0));
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
+ reductionIndex);
}
if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
- omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce);
+ omp::ReductionDeclareOp decl =
+ createDecl(builder, symbolTable, reduce, reductionIndex,
+ builder.getIntegerAttr(type, 0));
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
+ reductionIndex);
}
if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
- omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce);
+ omp::ReductionDeclareOp decl =
+ createDecl(builder, symbolTable, reduce, reductionIndex,
+ builder.getIntegerAttr(type, 0));
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
+ reductionIndex);
}
if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
omp::ReductionDeclareOp decl = createDecl(
- builder, symbolTable, reduce,
+ builder, symbolTable, reduce, reductionIndex,
builder.getIntegerAttr(
type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce);
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
+ reductionIndex);
}
// Match simple binary reductions that cannot be expressed with atomicrmw.
// TODO: add atomic region using cmpxchg (which needs atomic load to be
// available as an op).
if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
- return createDecl(builder, symbolTable, reduce,
+ return createDecl(builder, symbolTable, reduce, reductionIndex,
builder.getFloatAttr(type, 1.0));
}
if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
- return createDecl(builder, symbolTable, reduce,
+ return createDecl(builder, symbolTable, reduce, reductionIndex,
builder.getIntegerAttr(type, 1));
}
@@ -305,7 +318,7 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
{LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
- return createDecl(builder, symbolTable, reduce,
+ return createDecl(builder, symbolTable, reduce, reductionIndex,
minMaxValueForFloat(type, !isMin));
}
if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
@@ -314,11 +327,12 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
{LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
- omp::ReductionDeclareOp decl = createDecl(
- builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin));
+ omp::ReductionDeclareOp decl =
+ createDecl(builder, symbolTable, reduce, reductionIndex,
+ minMaxValueForSignedInt(type, !isMin));
return addAtomicRMW(builder,
isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
- decl, reduce);
+ decl, reduce, reductionIndex);
}
if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
@@ -326,11 +340,12 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
{LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
- omp::ReductionDeclareOp decl = createDecl(
- builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin));
+ omp::ReductionDeclareOp decl =
+ createDecl(builder, symbolTable, reduce, reductionIndex,
+ minMaxValueForUnsignedInt(type, !isMin));
return addAtomicRMW(
builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
- decl, reduce);
+ decl, reduce, reductionIndex);
}
return nullptr;
@@ -352,8 +367,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// TODO: consider checking it here is already a compatible reduction
// declaration and use it instead of redeclaring.
SmallVector<Attribute> reductionDeclSymbols;
- for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) {
- omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce);
+ auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
+ for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
+ omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce, i);
if (!decl)
return failure();
reductionDeclSymbols.push_back(
@@ -382,14 +398,13 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// Replace the reduction operations contained in this loop. Must be done
// here rather than in a separate pattern to have access to the list of
// reduction variables.
- for (auto pair :
- llvm::zip(parallelOp.getOps<scf::ReduceOp>(), reductionVariables)) {
+ for (auto [reductionVar, operand] :
+ llvm::zip_equal(reductionVariables, reduce.getOperands())) {
OpBuilder::InsertionGuard guard(rewriter);
- scf::ReduceOp reduceOp = std::get<0>(pair);
- rewriter.setInsertionPoint(reduceOp);
- rewriter.replaceOpWithNewOp<omp::ReductionOp>(
- reduceOp, reduceOp.getOperand(), std::get<1>(pair));
+ rewriter.setInsertionPoint(reduce);
+ rewriter.create<omp::ReductionOp>(reduce.getLoc(), operand, reductionVar);
}
+ rewriter.eraseOp(reduce);
Value numThreadsVar;
if (numThreads > 0) {
@@ -432,10 +447,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
rewriter.create<omp::YieldOp>(loc, ValueRange());
Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
rewriter.mergeBlocks(ops, scopeBlock);
- auto oldYield = cast<scf::YieldOp>(scopeBlock->getTerminator());
rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
- rewriter.replaceOpWithNewOp<memref::AllocaScopeReturnOp>(
- oldYield, oldYield->getOperands());
+ rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange());
if (!reductionVariables.empty()) {
loop.setReductionsAttr(
ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 12a28c2e23b221..428a3c945581b4 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -429,8 +429,9 @@ static ParallelComputeFunction createParallelComputeFunction(
mapping.map(op.getInductionVars(), computeBlockInductionVars);
mapping.map(computeFuncType.captures, captures);
- for (auto &bodyOp : op.getRegion().getOps())
+ for (auto &bodyOp : op.getRegion().front().without_terminator())
b.clone(bodyOp, mapping);
+ b.create<scf::YieldOp>(loc);
};
};
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 55bb5788108bdb..5570c2ec688c8a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2643,7 +2643,9 @@ void ParallelOp::build(
bodyBlock->getArguments().take_front(numIVs),
bodyBlock->getArguments().drop_front(numIVs));
}
- ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
+ // Add terminator only if there are no reductions.
+ if (initVals.empty())
+ ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
}
void ParallelOp::build(
@@ -2693,19 +2695,15 @@ LogicalResult ParallelOp::verify() {
return emitOpError(
"expects arguments for the induction variable to be of index type");
- // Check that the yield has no results
- auto yield = verifyAndGetTerminator<scf::YieldOp>(
- *this, getRegion(), "expects body to terminate with 'scf.yield'");
- if (!yield)
+ // Check that the terminator is an scf.reduce op.
+ auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
+ *this, getRegion(), "expects body to terminate with 'scf.reduce'");
+ if (!reduceOp)
return failure();
- if (yield->getNumOperands() != 0)
- return yield.emitOpError() << "not allowed to have operands inside '"
- << ParallelOp::getOperationName() << "'";
- // Check that the number of results is the same as the number of ReduceOps.
- SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
+ // Check that the number of results is the same as the number of reductions.
auto resultsSize = getResults().size();
- auto reductionsSize = reductions.size();
+ auto reductionsSize = reduceOp.getReductions().size();
auto initValsSize = getInitVals().size();
if (resultsSize != reductionsSize)
return emitOpError() << "expects number of results: " << resultsSize
@@ -2717,14 +2715,15 @@ LogicalResult ParallelOp::verify() {
<< initValsSize;
// Check that the types of the results and reductions are the same.
- for (auto resultAndReduce : llvm::zip(getResults(), reductions)) {
- auto resultType = std::get<0>(resultAndReduce).getType();
- auto reduceOp = std::get<1>(resultAndReduce);
- auto reduceType = reduceOp.getOperand().getType();
- if (resultType != reduceType)
+ for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
+ auto resultType = getOperation()->getResult(i).getType();
+ auto reductionOperandType = reduceOp.getOperands()[i].getType();
+ if (resultType != reductionOperandType)
return reduceOp.emitOpError()
- << "expects type of reduce: " << reduceType
- << " to be the same as result type: " << resultType;
+ << "expects type of " << i
+ << "-th reduction operand: " << reductionOperandType
+ << " to be the same as the " << i
+ << "-th result type: " << resultType;
}
return success();
}
@@ -2792,7 +2791,7 @@ ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
// Add a terminator if none was parsed.
- ForOp::ensureTerminator(*body, builder, result.location);
+ ParallelOp::ensureTerminator(*body, builder, result.location);
return success();
}
@@ -2887,17 +2886,15 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
// loop body and nested ReduceOp's
SmallVector<Value> results;
results.reserve(op.getInitVals().size());
- for (auto &bodyOp : op.getBody()->without_terminator()) {
- auto reduce = dyn_cast<ReduceOp>(bodyOp);
- if (!reduce) {
- rewriter.clone(bodyOp, mapping);
- continue;
- }
- Block &reduceBlock = reduce.getReductionOperator().front();
+ for (auto &bodyOp : op.getBody()->without_terminator())
+ rewriter.clone(bodyOp, mapping);
+ auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
+ for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
+ Block &reduceBlock = reduceOp.getReductions()[i].front();
auto initValIndex = results.size();
mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
mapping.map(reduceBlock.getArgument(1),
- mapping.lookupOrDefault(reduce.getOperand()));
+ mapping.lookupOrDefault(reduceOp.getOperands()[i]));
for (auto &reduceBodyOp : reduceBlock.without_terminator())
rewriter.clone(reduceBodyOp, mapping);
@@ -2905,6 +2902,7 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
results.push_back(result);
}
+
rewriter.replaceOp(op, results);
return success();
}
@@ -3008,67 +3006,48 @@ void ParallelOp::getSuccessorRegions(
// ReduceOp
//===----------------------------------------------------------------------===//
-void ReduceOp::build(
- OpBuilder &builder, OperationState &result, Value operand,
- function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
- auto type = operand.getType();
- result.addOperands(operand);
+void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
- OpBuilder::InsertionGuard guard(builder);
- Region *bodyRegion = result.addRegion();
- Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type},
- {result.location, result.location});
- if (bodyBuilderFn)
- bodyBuilderFn(builder, result.location, body->getArgument(0),
- body->getArgument(1));
+void ReduceOp::build(OpBuilder &builder, OperationState &result,
+ ValueRange operands) {
+ result.addOperands(operands);
+ for (Value v : operands) {
+ OpBuilder::InsertionGuard guard(builder);
+ Region *bodyRegion = result.addRegion();
+ builder.createBlock(bodyRegion, {},
+ ArrayRef<Type>{v.getType(), v.getType()},
+ {result.location, result.location});
+ }
}
LogicalResult ReduceOp::verifyRegions() {
- // The region of a ReduceOp has two arguments of the same type as its operand.
- auto type = getOperand().getType();
- Block &block = getReductionOperator().front();
- if (block.empty())
- return emitOpError("the block inside reduce should not be empty");
- if (block.getNumArguments() != 2 ||
- llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
- return arg.getType() != type;
- }))
- return emitOpError() << "expects two arguments to reduce block of type "
- << type;
-
- // Check that the block is terminated by a ReduceReturnOp.
- if (!isa<ReduceReturnOp>(block.getTerminator()))
- return emitOpError("the block inside reduce should be terminated with a "
- "'scf.reduce.return' op");
-
- return success();
-}
-
-ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
- // Parse an opening `(` followed by the reduced value followed by `)`
- OpAsmParser::UnresolvedOperand operand;
- if (parser.parseLParen() || parser.parseOperand(operand) ||
- parser.parseRParen())
- return failure();
-
- Type resultType;
- // Parse the type of the operand (and also what reduce computes on).
- if (parser.parseColonType(resultType) ||
- parser.resolveOperand(operand, resultType, result.operands))
- return failure();
-
- // Now parse the body.
- Region *body = result.addRegion();
- if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
- return failure();
+ // The region of a ReduceOp has two arguments of the same type as its
+ // corresponding operand.
+ for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
+ auto type = getOperands()[i].getType();
+ Block &block = getReductions()[i].front();
+ if (block.empty())
+ return emitOpError() << i << "-th reduction has an empty body";
+ if (block.getNumArguments() != 2 ||
+ llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
+ return arg.getType() != type;
+ }))
+ return emitOpError() << "expected two block arguments with type " << type
+ << " in the " << i << "-th reduction region";
+
+ // Check that the block is terminated by a ReduceReturnOp.
+ if (!isa<ReduceReturnOp>(block.getTerminator()))
+ return emitOpError("reduction bodies must be terminated with an "
+ "'scf.reduce.return' op");
+ }
return success();
}
-void ReduceOp::print(OpAsmPrinter &p) {
- p << "(" << getOperand() << ") ";
- p << " : " << getOperand().getType() << ' ';
- p.printRegion(getReductionOperator());
+MutableOperandRange
+ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+ // No operands are forwarded to the next iteration.
+ return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
}
//===----------------------------------------------------------------------===//
@@ -3076,13 +3055,15 @@ void ReduceOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult ReduceReturnOp::verify() {
- // The type of the return value should be the same type as the type of the
- // operand of the enclosing ReduceOp.
- auto reduceOp = cast<ReduceOp>((*this)->getParentOp());
- Type reduceType = reduceOp.getOperand().getType();
- if (reduceType != getResult().getType())
- return emitOpError() << "needs to have type " << reduceType
- << " (the type of the enclosing ReduceOp)";
+ // The type of the return value should be the same type as the types of the
+ // block arguments of the reduction body.
+ Block *reductionBody = getOperation()->getBlock();
+ // Should already be verified by an op trait.
+ assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
+ Type expectedResultType = reductionBody->getArgument(0).getType();
+ if (expectedResultType != getResult().getType())
+ return emitOpError() << "must have type " << expectedResultType
+ << " (the type of the reduction inputs)";
return success();
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
index fdc28060917fb2..ed73d81198f298 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -159,6 +159,11 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
/*hasElseRegion*/ false);
ifInbound.getThenRegion().takeBody(op.getRegion());
Block &thenBlock = ifInbound.getThenRegion().front();
+ // Replace the scf.reduce terminator with an scf.yield terminator.
+ Operation *reduceOp = thenBlock.getTerminator();
+ b.setInsertionPointToEnd(&thenBlock);
+ b.create<scf::YieldOp>(reduceOp->getLoc());
+ reduceOp->erase();
b.setInsertionPointToStart(innerLoop.getBody());
for (const auto &ivs : llvm::enumerate(llvm::zip(
innerLoop.getInductionVars(), outerLoop.getInductionVars()))) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 69fd1eb746ffe7..8af3b694c4d975 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -315,6 +315,9 @@ static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
rewriter.eraseBlock(forOp.getBody());
rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
forOp.getRegion().begin(), irMap);
+ // Replace the scf.reduce terminator.
+ rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator());
// Done.
rewriter.setInsertionPointAfter(forOp);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 35faf1769746d8..d60b6ccd732167 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -1371,7 +1371,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
rewriter.setInsertionPointAfter(redExp);
auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
// Attach to the reduction op.
- Block *redBlock = &redOp.getRegion().getBlocks().front();
+ Block *redBlock = &redOp.getReductions().front().front();
rewriter.setInsertionPointToEnd(redBlock);
Operation *newRed = rewriter.clone(*redExp);
// Replaces arguments of the reduction expression by using the block
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 6158de33e4aef2..92608135d24b08 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -763,7 +763,7 @@ func.func @affine_parallel_tiled(%o: memref<100x100xf32>, %a: memref<100x100xf32
// CHECK: %[[A3:.*]] = memref.load %[[ARG1]][%[[arg6]], %[[arg8]]] : memref<100x100xf32>
// CHECK: %[[A4:.*]] = memref.load %[[ARG2]][%[[arg8]], %[[arg7]]] : memref<100x100xf32>
// CHECK: arith.mulf %[[A3]], %[[A4]] : f32
-// CHECK: scf.yield
+// CHECK: scf.reduce
/////////////////////////////////////////////////////////////////////
@@ -789,7 +789,7 @@ func.func @affine_parallel_simple(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>
// CHECK-NEXT: %[[VAL_2:.*]] = memref.load
// CHECK-NEXT: %[[PRODUCT:.*]] = arith.mulf
// CHECK-NEXT: store
-// CHECK-NEXT: scf.yield
+// CHECK-NEXT: scf.reduce
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }
@@ -820,7 +820,7 @@ func.func @affine_parallel_simple_dynamic_bounds(%arg0: memref<?x?xf32>, %arg1:
// CHECK-NEXT: %[[VAL_2:.*]] = memref.load
// CHECK-NEXT: %[[PRODUCT:.*]] = arith.mulf
// CHECK-NEXT: store
-// CHECK-NEXT: scf.yield
+// CHECK-NEXT: scf.reduce
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }
@@ -851,17 +851,15 @@ func.func @affine_parallel_with_reductions(%arg0: memref<3x3xf32>, %arg1: memref
// CHECK-NEXT: %[[VAL_2:.*]] = memref.load
// CHECK-NEXT: %[[PRODUCT:.*]] = arith.mulf
// CHECK-NEXT: %[[SUM:.*]] = arith.addf
-// CHECK-NEXT: scf.reduce(%[[PRODUCT]]) : f32 {
+// CHECK-NEXT: scf.reduce(%[[PRODUCT]], %[[SUM]] : f32, f32) {
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK-NEXT: %[[RES:.*]] = arith.addf
// CHECK-NEXT: scf.reduce.return %[[RES]] : f32
-// CHECK-NEXT: }
-// CHECK-NEXT: scf.reduce(%[[SUM]]) : f32 {
+// CHECK-NEXT: }, {
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK-NEXT: %[[RES:.*]] = arith.mulf
// CHECK-NEXT: scf.reduce.return %[[RES]] : f32
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }
@@ -892,17 +890,15 @@ func.func @affine_parallel_with_reductions_f64(%arg0: memref<3x3xf64>, %arg1: me
// CHECK: %[[VAL_2:.*]] = memref.load
// CHECK: %[[PRODUCT:.*]] = arith.mulf
// CHECK: %[[SUM:.*]] = arith.addf
-// CHECK: scf.reduce(%[[PRODUCT]]) : f64 {
+// CHECK: scf.reduce(%[[PRODUCT]], %[[SUM]] : f64, f64) {
// CHECK: ^bb0(%[[LHS:.*]]: f64, %[[RHS:.*]]: f64):
// CHECK: %[[RES:.*]] = arith.addf
// CHECK: scf.reduce.return %[[RES]] : f64
-// CHECK: }
-// CHECK: scf.reduce(%[[SUM]]) : f64 {
+// CHECK: }, {
// CHECK: ^bb0(%[[LHS:.*]]: f64, %[[RHS:.*]]: f64):
// CHECK: %[[RES:.*]] = arith.mulf
// CHECK: scf.reduce.return %[[RES]] : f64
// CHECK: }
-// CHECK: scf.yield
// CHECK: }
/////////////////////////////////////////////////////////////////////
@@ -931,15 +927,13 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
// CHECK: %[[VAL_2:.*]] = memref.load
// CHECK: %[[PRODUCT:.*]] = arith.muli
// CHECK: %[[SUM:.*]] = arith.addi
-// CHECK: scf.reduce(%[[PRODUCT]]) : i64 {
+// CHECK: scf.reduce(%[[PRODUCT]], %[[SUM]] : i64, i64) {
// CHECK: ^bb0(%[[LHS:.*]]: i64, %[[RHS:.*]]: i64):
// CHECK: %[[RES:.*]] = arith.addi
// CHECK: scf.reduce.return %[[RES]] : i64
-// CHECK: }
-// CHECK: scf.reduce(%[[SUM]]) : i64 {
+// CHECK: }, {
// CHECK: ^bb0(%[[LHS:.*]]: i64, %[[RHS:.*]]: i64):
// CHECK: %[[RES:.*]] = arith.muli
// CHECK: scf.reduce.return %[[RES]] : i64
// CHECK: }
-// CHECK: scf.yield
// CHECK: }
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index 99b47ea94cc0b1..caf17bc91ced23 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -254,6 +254,7 @@ func.func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
step (%arg4, %step) {
%c1 = arith.constant 1 : index
+ scf.reduce
}
return
}
@@ -347,7 +348,7 @@ func.func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
// CHECK: return %[[ITER_ARG]]
%0 = scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) -> f32 {
%cst = arith.constant 42.0 : f32
- scf.reduce(%cst) : f32 {
+ scf.reduce(%cst : f32) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.mulf %lhs, %rhs : f32
scf.reduce.return %1 : f32
@@ -383,14 +384,12 @@ func.func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
%0:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
step (%arg4, %step) init(%arg5, %init) -> (f32, i64) {
%cf = arith.constant 42.0 : f32
- scf.reduce(%cf) : f32 {
+ %2 = func.call @generate() : () -> i64
+ scf.reduce(%cf, %2 : f32, i64) {
^bb0(%lhs: f32, %rhs: f32):
%1 = arith.addf %lhs, %rhs : f32
scf.reduce.return %1 : f32
- }
-
- %2 = func.call @generate() : () -> i64
- scf.reduce(%2) : i64 {
+ }, {
^bb0(%lhs: i64, %rhs: i64):
%3 = arith.ori %lhs, %rhs : i64
scf.reduce.return %3 : i64
@@ -580,7 +579,7 @@ func.func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1,
scf.yield %2 : index
}
}
- scf.yield
+ scf.reduce
}
// CHECK: ^[[LOOP_CONT]]:
diff --git a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
index deeaec2f81a94e..59441e5ed66290 100644
--- a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
+++ b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
@@ -232,9 +232,9 @@ module {
%19 = memref.load %16[%arg5, %arg6] : memref<?x?xf32, strided<[?, ?], offset: ?>>
%20 = arith.addf %17, %18 : f32
memref.store %20, %16[%arg5, %arg6] : memref<?x?xf32, strided<[?, ?], offset: ?>>
- scf.yield
+ scf.reduce
} {mapping = [#gpu.loop_dim_map<bound = (d0) -> (d0), map = (d0) -> (d0), processor = thread_x>, #gpu.loop_dim_map<bound = (d0) -> (d0), map = (d0) -> (d0), processor = thread_y>]}
- scf.yield
+ scf.reduce
} {mapping = [#gpu.loop_dim_map<bound = (d0) -> (d0), map = (d0) -> (d0), processor = block_x>, #gpu.loop_dim_map<bound = (d0) -> (d0), map = (d0) -> (d0), processor = block_y>]}
return
}
@@ -404,9 +404,9 @@ func.func @step_invariant() {
%1 = memref.load %alloc_0[%arg0, %arg1] : memref<1x1xf64>
%2 = arith.addf %0, %1 : f64
memref.store %2, %alloc[%arg0, %arg1] : memref<1x1xf64>
- scf.yield
+ scf.reduce
} {mapping = [#gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
- scf.yield
+ scf.reduce
} {mapping = [#gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
memref.dealloc %alloc_1 : memref<1x1xf64>
memref.dealloc %alloc_0 : memref<1x1xf64>
diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
index 25b18b58a6adbd..faf5ec4aba7d4d 100644
--- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
@@ -34,7 +34,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK: %[[CST_INNER:.*]] = arith.constant 1.0
%one = arith.constant 1.0 : f32
// CHECK: omp.reduction %[[CST_INNER]], %[[BUF]]
- scf.reduce(%one) : f32 {
+ scf.reduce(%one : f32) {
^bb0(%lhs : f32, %rhs: f32):
%res = arith.addf %lhs, %rhs : f32
scf.reduce.return %res : f32
@@ -70,7 +70,7 @@ func.func @reduction2(%arg0 : index, %arg1 : index, %arg2 : index,
scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
step (%arg4, %step) init (%zero) -> (f32) {
%one = arith.constant 1.0 : f32
- scf.reduce(%one) : f32 {
+ scf.reduce(%one : f32) {
^bb0(%lhs : f32, %rhs: f32):
%res = arith.mulf %lhs, %rhs : f32
scf.reduce.return %res : f32
@@ -107,7 +107,7 @@ func.func @reduction_muli(%arg0 : index, %arg1 : index, %arg2 : index,
step (%arg4, %step) init (%one) -> (i32) {
// CHECK: omp.reduction
%pow2 = arith.constant 2 : i32
- scf.reduce(%pow2) : i32 {
+ scf.reduce(%pow2 : i32) {
^bb0(%lhs : i32, %rhs: i32):
%res = arith.muli %lhs, %rhs : i32
scf.reduce.return %res : i32
@@ -141,7 +141,7 @@ func.func @reduction3(%arg0 : index, %arg1 : index, %arg2 : index,
scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
step (%arg4, %step) init (%zero) -> (f32) {
%one = arith.constant 1.0 : f32
- scf.reduce(%one) : f32 {
+ scf.reduce(%one : f32) {
^bb0(%lhs : f32, %rhs: f32):
%cmp = arith.cmpf oge, %lhs, %rhs : f32
%res = arith.select %cmp, %lhs, %rhs : f32
@@ -205,17 +205,16 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
%res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
%one = arith.constant 1.0 : f32
+ // CHECK: arith.fptosi
+ %1 = arith.fptosi %one : f32 to i64
// CHECK: omp.reduction %{{.*}}, %[[BUF1]]
- scf.reduce(%one) : f32 {
+ // CHECK: omp.reduction %{{.*}}, %[[BUF2]]
+ scf.reduce(%one, %1 : f32, i64) {
^bb0(%lhs : f32, %rhs: f32):
%cmp = arith.cmpf oge, %lhs, %rhs : f32
%res = arith.select %cmp, %lhs, %rhs : f32
scf.reduce.return %res : f32
- }
- // CHECK: arith.fptosi
- %1 = arith.fptosi %one : f32 to i64
- // CHECK: omp.reduction %{{.*}}, %[[BUF2]]
- scf.reduce(%1) : i64 {
+ }, {
^bb1(%lhs: i64, %rhs: i64):
%cmp = arith.cmpi slt, %lhs, %rhs : i64
%res = arith.select %cmp, %rhs, %lhs : i64
diff --git a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
index 6f388f366f7447..71bf2f3d918e83 100644
--- a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
@@ -1,13 +1,13 @@
// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
// `scf.parallel` conversion is not supported yet.
-// Make sure that we do not accidentally invalidate this functio by removing
-// `scf.yield`.
+// Make sure that we do not accidentally invalidate this function by removing
+// `scf.reduce`.
// CHECK-LABEL: func.func @func
// CHECK: scf.parallel
// CHECK-NEXT: spirv.Constant
// CHECK-NEXT: memref.store
-// CHECK-NEXT: scf.yield
+// CHECK-NEXT: scf.reduce
// CHECK: spirv.Return
func.func @func(%arg0: i64) {
%0 = arith.index_cast %arg0 : i64 to index
@@ -15,7 +15,7 @@ func.func @func(%arg0: i64) {
scf.parallel (%arg1) = (%0) to (%0) step (%0) {
%cst = arith.constant 1.000000e+00 : f32
memref.store %cst, %alloc[%arg1] : memref<16xf32>
- scf.yield
+ scf.reduce
}
return
}
diff --git a/mlir/test/Dialect/Linalg/parallel-loops.mlir b/mlir/test/Dialect/Linalg/parallel-loops.mlir
index 15bce63caabcfd..c04f27608d4452 100644
--- a/mlir/test/Dialect/Linalg/parallel-loops.mlir
+++ b/mlir/test/Dialect/Linalg/parallel-loops.mlir
@@ -25,7 +25,7 @@ func.func @linalg_generic_sum(%lhs: memref<2x2xf32>,
// CHECK: %[[RHS_ELEM:.*]] = memref.load %[[RHS]][%[[I]], %[[J]]]
// CHECK: %[[SUM:.*]] = arith.addf %[[LHS_ELEM]], %[[RHS_ELEM]] : f32
// CHECK: store %[[SUM]], %{{.*}}[%[[I]], %[[J]]]
-// CHECK: scf.yield
+// CHECK: scf.reduce
// -----
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index fed3c007d9b6d8..15942db9b5db20 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -153,7 +153,7 @@ func.func @foo(%lb: index, %ub: index, %step: index) {
// expected-remark @below {{loop-like}}
scf.parallel (%i) = (%lb) to (%ub) step (%step) {
func.call @callee() : () -> ()
- scf.yield
+ scf.reduce
}
// expected-remark @below {{loop-like}}
scf.forall (%i) in (%ub) {
diff --git a/mlir/test/Dialect/SCF/buffer-deallocation.mlir b/mlir/test/Dialect/SCF/buffer-deallocation.mlir
index 99cfed99c02d1a..8451b1524fd2a0 100644
--- a/mlir/test/Dialect/SCF/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/SCF/buffer-deallocation.mlir
@@ -31,7 +31,7 @@ func.func @reduce(%buffer: memref<100xf32>) {
%c1 = arith.constant 1 : index
scf.parallel (%iv) = (%c0) to (%c1) step (%c1) init (%init) -> f32 {
%elem_to_reduce = memref.load %buffer[%iv] : memref<100xf32>
- scf.reduce(%elem_to_reduce) : f32 {
+ scf.reduce(%elem_to_reduce : f32) {
^bb0(%lhs : f32, %rhs: f32):
%alloc = memref.alloc() : memref<2xf32>
memref.store %lhs, %alloc [%c0] : memref<2xf32>
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 41e028028616a7..52e0fdfa36d6cd 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -11,7 +11,7 @@ func.func @single_iteration_some(%A: memref<?x?x?xi32>) {
scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c2, %c3) {
%c42 = arith.constant 42 : i32
memref.store %c42, %A[%i0, %i1, %i2] : memref<?x?x?xi32>
- scf.yield
+ scf.reduce
}
return
}
@@ -26,7 +26,7 @@ func.func @single_iteration_some(%A: memref<?x?x?xi32>) {
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
// CHECK: scf.parallel ([[V0:%.*]]) = ([[C3]]) to ([[C6]]) step ([[C2]]) {
// CHECK: memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V0]], [[C7]]] : memref<?x?x?xi32>
-// CHECK: scf.yield
+// CHECK: scf.reduce
// CHECK: }
// CHECK: return
@@ -42,7 +42,7 @@ func.func @single_iteration_all(%A: memref<?x?x?xi32>) {
scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c3, %c3) {
%c42 = arith.constant 42 : i32
memref.store %c42, %A[%i0, %i1, %i2] : memref<?x?x?xi32>
- scf.yield
+ scf.reduce
}
return
}
@@ -55,7 +55,7 @@ func.func @single_iteration_all(%A: memref<?x?x?xi32>) {
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
// CHECK-NOT: scf.parallel
// CHECK: memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[C3]], [[C7]]] : memref<?x?x?xi32>
-// CHECK-NOT: scf.yield
+// CHECK-NOT: scf.reduce
// CHECK: return
// -----
@@ -67,17 +67,15 @@ func.func @single_iteration_reduce(%A: index, %B: index) -> (index, index) {
%c3 = arith.constant 3 : index
%c6 = arith.constant 6 : index
%0:2 = scf.parallel (%i0, %i1) = (%c1, %c3) to (%c2, %c6) step (%c1, %c3) init(%A, %B) -> (index, index) {
- scf.reduce(%i0) : index {
+ scf.reduce(%i0, %i1 : index, index) {
^bb0(%lhs: index, %rhs: index):
%1 = arith.addi %lhs, %rhs : index
scf.reduce.return %1 : index
- }
- scf.reduce(%i1) : index {
+ }, {
^bb0(%lhs: index, %rhs: index):
%2 = arith.muli %lhs, %rhs : index
scf.reduce.return %2 : index
}
- scf.yield
}
return %0#0, %0#1 : index, index
}
@@ -109,11 +107,11 @@ func.func @nested_parallel(%0: memref<?x?x?xf64>) -> memref<?x?x?xf64> {
scf.parallel (%arg3) = (%c0) to (%3) step (%c1) {
%5 = memref.load %0[%arg1, %arg2, %arg3] : memref<?x?x?xf64>
memref.store %5, %4[%arg1, %arg2, %arg3] : memref<?x?x?xf64>
- scf.yield
+ scf.reduce
}
- scf.yield
+ scf.reduce
}
- scf.yield
+ scf.reduce
}
return %4 : memref<?x?x?xf64>
}
@@ -759,12 +757,11 @@ func.func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
// CHECK-NOT: test.transform
%0 = scf.parallel (%i, %j, %k) = (%lb, %ub, %lb) to (%ub, %ub, %ub) step (%s, %s, %s) init(%init) -> f32 {
%1 = "test.produce"() : () -> f32
- scf.reduce(%1) : f32 {
+ scf.reduce(%1 : f32) {
^bb0(%lhs: f32, %rhs: f32):
%2 = "test.transform"(%lhs, %rhs) : (f32, f32) -> f32
scf.reduce.return %2 : f32
}
- scf.yield
}
// CHECK: "test.consume"(%[[INIT]])
"test.consume"(%0) : (f32) -> ()
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index ad07a8b11327de..fac9d825568f72 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -235,7 +235,7 @@ func.func @parallel_fewer_results_than_reduces(
// expected-error at +1 {{expects number of results: 0 to be the same as number of reductions: 1}}
scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
%c0 = arith.constant 1.0 : f32
- scf.reduce(%c0) : f32 {
+ scf.reduce(%c0 : f32) {
^bb0(%lhs: f32, %rhs: f32):
scf.reduce.return %lhs : f32
}
@@ -261,7 +261,7 @@ func.func @parallel_more_results_than_initial_values(
%arg0 : index, %arg1: index, %arg2: index) {
// expected-error at +1 {{'scf.parallel' 0 operands present, but expected 1}}
%res = scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) -> f32 {
- scf.reduce(%arg0) : index {
+ scf.reduce(%arg0 : index) {
^bb0(%lhs: index, %rhs: index):
scf.reduce.return %lhs : index
}
@@ -275,8 +275,8 @@ func.func @parallel_different_types_of_results_and_reduces(
%zero = arith.constant 0.0 : f32
%res = scf.parallel (%i0) = (%arg0) to (%arg1)
step (%arg2) init (%zero) -> f32 {
- // expected-error at +1 {{expects type of reduce: 'index' to be the same as result type: 'f32'}}
- scf.reduce(%arg0) : index {
+ // expected-error at +1 {{expects type of 0-th reduction operand: 'index' to be the same as the 0-th result type: 'f32'}}
+ scf.reduce(%arg0 : index) {
^bb0(%lhs: index, %rhs: index):
scf.reduce.return %lhs : index
}
@@ -288,7 +288,7 @@ func.func @parallel_different_types_of_results_and_reduces(
func.func @top_level_reduce(%arg0 : f32) {
// expected-error at +1 {{expects parent op 'scf.parallel'}}
- scf.reduce(%arg0) : f32 {
+ scf.reduce(%arg0 : f32) {
^bb0(%lhs : f32, %rhs : f32):
scf.reduce.return %lhs : f32
}
@@ -302,7 +302,7 @@ func.func @reduce_empty_block(%arg0 : index, %arg1 : f32) {
%res = scf.parallel (%i0) = (%arg0) to (%arg0)
step (%arg0) init (%zero) -> f32 {
// expected-error@+1 {{empty block: expect at least a terminator}}
- scf.reduce(%arg1) : f32 {
+ scf.reduce(%arg1 : f32) {
^bb0(%lhs : f32, %rhs : f32):
}
}
@@ -315,8 +315,8 @@ func.func @reduce_too_many_args(%arg0 : index, %arg1 : f32) {
%zero = arith.constant 0.0 : f32
%res = scf.parallel (%i0) = (%arg0) to (%arg0)
step (%arg0) init (%zero) -> f32 {
- // expected-error@+1 {{expects two arguments to reduce block of type 'f32'}}
- scf.reduce(%arg1) : f32 {
+ // expected-error@+1 {{expected two block arguments with type 'f32' in the 0-th reduction region}}
+ scf.reduce(%arg1 : f32) {
^bb0(%lhs : f32, %rhs : f32, %other : f32):
scf.reduce.return %lhs : f32
}
@@ -330,8 +330,8 @@ func.func @reduce_wrong_args(%arg0 : index, %arg1 : f32) {
%zero = arith.constant 0.0 : f32
%res = scf.parallel (%i0) = (%arg0) to (%arg0)
step (%arg0) init (%zero) -> f32 {
- // expected-error@+1 {{expects two arguments to reduce block of type 'f32'}}
- scf.reduce(%arg1) : f32 {
+ // expected-error@+1 {{expected two block arguments with type 'f32' in the 0-th reduction region}}
+ scf.reduce(%arg1 : f32) {
^bb0(%lhs : f32, %rhs : i32):
scf.reduce.return %lhs : f32
}
@@ -346,8 +346,8 @@ func.func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) {
%zero = arith.constant 0.0 : f32
%res = scf.parallel (%i0) = (%arg0) to (%arg0)
step (%arg0) init (%zero) -> f32 {
- // expected-error@+1 {{the block inside reduce should be terminated with a 'scf.reduce.return' op}}
- scf.reduce(%arg1) : f32 {
+ // expected-error@+1 {{reduction bodies must be terminated with an 'scf.reduce.return' op}}
+ scf.reduce(%arg1 : f32) {
^bb0(%lhs : f32, %rhs : f32):
"test.finish" () : () -> ()
}
@@ -361,10 +361,10 @@ func.func @reduceReturn_wrong_type(%arg0 : index, %arg1: f32) {
%zero = arith.constant 0.0 : f32
%res = scf.parallel (%i0) = (%arg0) to (%arg0)
step (%arg0) init (%zero) -> f32 {
- scf.reduce(%arg1) : f32 {
+ scf.reduce(%arg1 : f32) {
^bb0(%lhs : f32, %rhs : f32):
%c0 = arith.constant 1 : index
- // expected-error@+1 {{needs to have type 'f32' (the type of the enclosing ReduceOp)}}
+ // expected-error@+1 {{must have type 'f32' (the type of the reduction inputs)}}
scf.reduce.return %c0 : index
}
}
@@ -475,9 +475,10 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
func.func @parallel_invalid_yield(
%arg0: index, %arg1: index, %arg2: index) {
+ // expected-error@below {{expects body to terminate with 'scf.reduce'}}
scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
%c0 = arith.constant 1.0 : f32
- // expected-error@+1 {{'scf.yield' op not allowed to have operands inside 'scf.parallel'}}
+ // expected-note@below {{terminator here}}
scf.yield %c0 : f32
}
return
@@ -487,7 +488,7 @@ func.func @parallel_invalid_yield(
func.func @yield_invalid_parent_op() {
"my.op"() ({
- // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.parallel, scf.while'}}
+ // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.while'}}
scf.yield
}) : () -> ()
return
@@ -749,7 +750,7 @@ func.func @switch_missing_terminator(%arg0: index, %arg1: i32) {
// -----
func.func @parallel_missing_terminator(%0 : index) {
- // expected-error @below {{'scf.parallel' op expects body to terminate with 'scf.yield'}}
+ // expected-error @below {{expects body to terminate with 'scf.reduce'}}
"scf.parallel"(%0, %0, %0) ({
^bb0(%arg1: index):
// expected-note @below {{terminator here}}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 46d175d6870ce0..7f457ef3b6ba0c 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -87,18 +87,18 @@ func.func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
%red:2 = scf.parallel (%i2) = (%min) to (%max) step (%i1)
init (%zero, %int_zero) -> (f32, i32) {
%one = arith.constant 1.0 : f32
- scf.reduce(%one) : f32 {
+ %int_one = arith.constant 1 : i32
+ scf.reduce(%one, %int_one : f32, i32) {
^bb0(%lhs : f32, %rhs: f32):
%res = arith.addf %lhs, %rhs : f32
scf.reduce.return %res : f32
- }
- %int_one = arith.constant 1 : i32
- scf.reduce(%int_one) : i32 {
+ }, {
^bb0(%lhs : i32, %rhs: i32):
%res = arith.muli %lhs, %rhs : i32
scf.reduce.return %res : i32
}
}
+ scf.reduce
}
return
}
@@ -121,25 +121,23 @@ func.func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK-SAME: step (%[[I1]])
// CHECK-SAME: init (%[[ZERO]], %[[INT_ZERO]]) -> (f32, i32) {
// CHECK-NEXT: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: scf.reduce(%[[ONE]]) : f32 {
+// CHECK-NEXT: %[[INT_ONE:.*]] = arith.constant 1 : i32
+// CHECK-NEXT: scf.reduce(%[[ONE]], %[[INT_ONE]] : f32, i32) {
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK-NEXT: %[[RES:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: scf.reduce.return %[[RES]] : f32
-// CHECK-NEXT: }
-// CHECK-NEXT: %[[INT_ONE:.*]] = arith.constant 1 : i32
-// CHECK-NEXT: scf.reduce(%[[INT_ONE]]) : i32 {
+// CHECK-NEXT: }, {
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
// CHECK-NEXT: %[[RES:.*]] = arith.muli %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: scf.reduce.return %[[RES]] : i32
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
-// CHECK-NEXT: scf.yield
+// CHECK-NEXT: scf.reduce
func.func @parallel_explicit_yield(
%arg0: index, %arg1: index, %arg2: index) {
scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
- scf.yield
+ scf.reduce
}
return
}
@@ -149,7 +147,7 @@ func.func @parallel_explicit_yield(
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]:
// CHECK-NEXT: scf.parallel (%{{.*}}) = (%[[ARG0]]) to (%[[ARG1]]) step (%[[ARG2]])
-// CHECK-NEXT: scf.yield
+// CHECK-NEXT: scf.reduce
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 8a42b3a1000ed6..9fd33b4e524717 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -5,10 +5,10 @@ func.func @fuse_empty_loops() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- scf.yield
+ scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- scf.yield
+ scf.reduce
}
return
}
@@ -18,7 +18,7 @@ func.func @fuse_empty_loops() {
// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
-// CHECK: scf.yield
+// CHECK: scf.reduce
// CHECK: }
// CHECK-NOT: scf.parallel
@@ -35,14 +35,14 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %C_elem : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
- scf.yield
+ scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
- scf.yield
+ scf.reduce
}
memref.dealloc %sum : memref<2x2xf32>
return
@@ -64,7 +64,7 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
-// CHECK: scf.yield
+// CHECK: scf.reduce
// CHECK: }
// CHECK: memref.dealloc [[SUM]]
@@ -81,20 +81,20 @@ func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
%rhs_elem = memref.load %rhs[%i] : memref<100xf32>
memref.store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
- scf.yield
+ scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
%lhs_elem = memref.load %lhs[%i, %j] : memref<100x10xf32>
%broadcast_rhs_elem = memref.load %broadcast_rhs[%i, %j] : memref<100x10xf32>
%diff_elem = arith.subf %lhs_elem, %broadcast_rhs_elem : f32
memref.store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
- scf.yield
+ scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
%diff_elem = memref.load %diff[%i, %j] : memref<100x10xf32>
%exp_elem = math.exp %diff_elem : f32
memref.store %exp_elem, %result[%i, %j] : memref<100x10xf32>
- scf.yield
+ scf.reduce
}
memref.dealloc %broadcast_rhs : memref<100x10xf32>
memref.dealloc %diff : memref<100x10xf32>
@@ -120,7 +120,7 @@ func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
// CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
// CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
-// CHECK: scf.yield
+// CHECK: scf.reduce
// CHECK: }
// CHECK: memref.dealloc [[BROADCAST_RHS]]
// CHECK: memref.dealloc [[DIFF]]
@@ -133,12 +133,12 @@ func.func @do_not_fuse_nested_ploop1() {
%c1 = arith.constant 1 : index
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- scf.yield
+ scf.reduce
}
- scf.yield
+ scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- scf.yield
+ scf.reduce
}
return
}
@@ -154,13 +154,13 @@ func.func @do_not_fuse_nested_ploop2() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- scf.yield
+ scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- scf.yield
+ scf.reduce
}
- scf.yield
+ scf.reduce
}
return
}
@@ -176,10 +176,10 @@ func.func @do_not_fuse_loops_unmatching_num_loops() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- scf.yield
+ scf.reduce
}
scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
- scf.yield
+ scf.reduce
}
return
}
@@ -194,11 +194,11 @@ func.func @do_not_fuse_loops_with_side_effecting_ops_in_between() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- scf.yield
+ scf.reduce
}
%buffer = memref.alloc() : memref<2x2xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- scf.yield
+ scf.reduce
}
return
}
@@ -214,10 +214,10 @@ func.func @do_not_fuse_loops_unmatching_iteration_space() {
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
scf.parallel (%i, %j) = (%c0, %c0) to (%c4, %c4) step (%c2, %c2) {
- scf.yield
+ scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- scf.yield
+ scf.reduce
}
return
}
@@ -239,7 +239,7 @@ func.func @do_not_fuse_unmatching_write_read_patterns(
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %C_elem : f32
memref.store %sum_elem, %common_buf[%i, %j] : memref<2x2xf32>
- scf.yield
+ scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%k = arith.addi %i, %c1 : index
@@ -247,7 +247,7 @@ func.func @do_not_fuse_unmatching_write_read_patterns(
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
- scf.yield
+ scf.reduce
}
memref.dealloc %common_buf : memref<2x2xf32>
return
@@ -269,7 +269,7 @@ func.func @do_not_fuse_unmatching_read_write_patterns(
%C_elem = memref.load %common_buf[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %C_elem : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
- scf.yield
+ scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%k = arith.addi %i, %c1 : index
@@ -277,7 +277,7 @@ func.func @do_not_fuse_unmatching_read_write_patterns(
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %common_buf[%j, %i] : memref<2x2xf32>
- scf.yield
+ scf.reduce
}
memref.dealloc %sum : memref<2x2xf32>
return
@@ -294,13 +294,13 @@ func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
%c1 = arith.constant 1 : index
%buffer = memref.alloc() : memref<2x2xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- scf.yield
+ scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%A = memref.subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1]
: memref<2x2xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%A_elem = memref.load %A[%i, %j] : memref<?x?xf32, strided<[?, ?], offset: ?>>
- scf.yield
+ scf.reduce
}
return
}
@@ -322,14 +322,14 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %C_elem : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
- scf.yield
+ scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
- scf.yield
+ scf.reduce
}
}
memref.dealloc %sum : memref<2x2xf32>
@@ -353,7 +353,7 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
-// CHECK: scf.yield
+// CHECK: scf.reduce
// CHECK: }
// CHECK: }
// CHECK: memref.dealloc [[SUM]]
@@ -371,14 +371,14 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
%sum_elem = arith.addf %B_elem, %C_elem : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
- scf.yield
+ scf.reduce
}
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
- scf.yield
+ scf.reduce
}
return
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir b/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir
index 7a35e0ff0c3a97..61b50bcd7d0c63 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir
@@ -36,15 +36,14 @@
// CHECK: %[[TMP_12:.*]] = memref.load %[[TMP_2]][%[[TMP_arg4]]] : memref<?xf32>
// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_3]][%[[TMP_11]]] : memref<32xf32>
// CHECK: %[[TMP_14:.*]] = arith.mulf %[[TMP_12]], %[[TMP_13]] : f32
-// CHECK: scf.reduce(%[[TMP_14]]) : f32 {
+// CHECK: scf.reduce(%[[TMP_14]] : f32) {
// CHECK: ^bb0(%[[TMP_arg5:.*]]: f32, %[[TMP_arg6:.*]]: f32):
// CHECK: %[[TMP_15:.*]] = arith.addf %[[TMP_arg5]], %[[TMP_arg6]] : f32
// CHECK: scf.reduce.return %[[TMP_15]] : f32
// CHECK: }
-// CHECK: scf.yield
// CHECK: }
// CHECK: memref.store %[[TMP_10]], %[[TMP_4]][%[[TMP_arg3]]] : memref<16xf32>
-// CHECK: scf.yield
+// CHECK: scf.reduce
// CHECK: }
// CHECK: %[[TMP_5:.*]] = bufferization.to_tensor %[[TMP_4]] : memref<16xf32>
// CHECK: return %[[TMP_5]] : tensor<16xf32>
diff --git a/mlir/test/Transforms/invalid-parallel-loop-collapsing.mlir b/mlir/test/Transforms/invalid-parallel-loop-collapsing.mlir
index 6f98d2c062a25d..4a3e4dc35d4f11 100644
--- a/mlir/test/Transforms/invalid-parallel-loop-collapsing.mlir
+++ b/mlir/test/Transforms/invalid-parallel-loop-collapsing.mlir
@@ -20,7 +20,7 @@
func.func @too_few_iters(%arg0: index, %arg1: index, %arg2: index) {
// expected-error @+1 {{op has 1 iter args while this limited functionality testing pass was configured only for loops with exactly 2 iter args.}}
scf.parallel (%arg3) = (%arg0) to (%arg1) step (%arg2) {
- scf.yield
+ scf.reduce
}
return
}
@@ -28,7 +28,7 @@ func.func @too_few_iters(%arg0: index, %arg1: index, %arg2: index) {
func.func @too_many_iters(%arg0: index, %arg1: index, %arg2: index) {
// expected-error @+1 {{op has 3 iter args while this limited functionality testing pass was configured only for loops with exactly 2 iter args.}}
scf.parallel (%arg3, %arg4, %arg5) = (%arg0, %arg0, %arg0) to (%arg1, %arg1, %arg1) step (%arg2, %arg2, %arg2) {
- scf.yield
+ scf.reduce
}
return
}
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index 1415583dde9da7..dcc314f36ae0a8 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -374,7 +374,7 @@ func.func @parallel_loop_with_invariant() {
// CHECK-NEXT: arith.addi
// CHECK-NEXT: scf.parallel (%[[A:.*]],{{.*}}) =
// CHECK-NEXT: arith.addi %[[A]]
- // CHECK-NEXT: yield
+ // CHECK-NEXT: reduce
// CHECK-NEXT: }
// CHECK-NEXT: return
diff --git a/mlir/test/Transforms/parallel-loop-collapsing.mlir b/mlir/test/Transforms/parallel-loop-collapsing.mlir
index c606fe7588526a..660d7edb2fbb37 100644
--- a/mlir/test/Transforms/parallel-loop-collapsing.mlir
+++ b/mlir/test/Transforms/parallel-loop-collapsing.mlir
@@ -43,4 +43,4 @@ func.func @parallel_many_dims() {
// CHECK: [[V2:%.*]] = arith.muli [[V0]], [[C10]] : index
// CHECK: [[I3:%.*]] = arith.addi [[V2]], [[C9]] : index
// CHECK: "magic.op"([[I0]], [[C3]], [[C6]], [[I3]], [[C12]]) : (index, index, index, index, index) -> index
-// CHECK: scf.yield
+// CHECK: scf.reduce
diff --git a/mlir/test/Transforms/single-parallel-loop-collapsing.mlir b/mlir/test/Transforms/single-parallel-loop-collapsing.mlir
index 7b6883896dc108..542786b5fa5e57 100644
--- a/mlir/test/Transforms/single-parallel-loop-collapsing.mlir
+++ b/mlir/test/Transforms/single-parallel-loop-collapsing.mlir
@@ -29,6 +29,6 @@ func.func @collapse_to_single() {
// CHECK: [[V1:%.*]] = arith.muli [[I1_COUNT]], [[C3]] : index
// CHECK: [[I0:%.*]] = arith.addi [[V1]], [[C3]] : index
// CHECK: "magic.op"([[I0]], [[I1]]) : (index, index) -> index
-// CHECK: scf.yield
+// CHECK: scf.reduce
// CHECK-NEXT: }
// CHECK-NEXT: return