[Mlir-commits] [mlir] 98a6edd - [mlir][Interfaces] `LoopLikeOpInterface`: Expose tied loop results (#70535)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 31 16:34:18 PDT 2023
Author: Matthias Springer
Date: 2023-11-01T08:34:14+09:00
New Revision: 98a6edd38f960679e65124d52e3c61f4abd1419f
URL: https://github.com/llvm/llvm-project/commit/98a6edd38f960679e65124d52e3c61f4abd1419f
DIFF: https://github.com/llvm/llvm-project/commit/98a6edd38f960679e65124d52e3c61f4abd1419f.diff
LOG: [mlir][Interfaces] `LoopLikeOpInterface`: Expose tied loop results (#70535)
Expose loop results, which correspond to the region iter_arg values that
are returned from the loop when there are no more iterations. Exposing
loop results is optional because some loops (e.g., `scf.while`) do not
have a 1-to-1 mapping between region iter_args and op results.
Also add helper functions to query tied
results/iter_args/inits.
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/include/mlir/Interfaces/LoopLikeInterface.td
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Interfaces/LoopLikeInterface.cpp
mlir/test/Dialect/SCF/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 43beebc1bf54166..38937fe28949436 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -269,28 +269,6 @@ def ForOp : SCF_Op<"for",
/// Number of operands controlling the loop: lb, ub, step
unsigned getNumControlOperands() { return 3; }
- /// Get the OpResult that corresponds to an OpOperand.
- /// Assert that opOperand is an iterArg.
- /// This helper prevents internal op implementation detail leakage to
- /// clients by hiding the operand / block argument mapping.
- OpResult getResultForOpOperand(OpOperand &opOperand) {
- assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
- "expected an iter args operand");
- assert(opOperand.getOwner() == getOperation() &&
- "opOperand does not belong to this scf::ForOp operation");
- return getOperation()->getResult(
- opOperand.getOperandNumber() - getNumControlOperands());
- }
- /// Get the OpOperand& that corresponds to an OpResultOpOperand.
- /// This helper prevents internal op implementation detail leakage to
- /// clients by hiding the operand / block argument mapping.
- OpOperand &getOpOperandForResult(OpResult opResult) {
- assert(opResult.getDefiningOp() == getOperation() &&
- "opResult does not belong to the scf::ForOp operation");
- return getOperation()->getOpOperand(
- getNumControlOperands() + opResult.getResultNumber());
- }
-
/// Returns the step as an `APInt` if it is constant.
std::optional<APInt> getConstantStep();
@@ -942,7 +920,7 @@ def WhileOp : SCF_Op<"while",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
- ["getLoopResults", "getRegionIterArgs", "getYieldedValuesMutable"]>,
+ ["getRegionIterArgs", "getYieldedValuesMutable"]>,
RecursiveMemoryEffects, SingleBlock]> {
let summary = "a generic 'while' loop";
let description = [{
@@ -1156,7 +1134,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
//===----------------------------------------------------------------------===//
def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
- ParentOneOf<["ExecuteRegionOp, ForOp", "IfOp", "IndexSwitchOp",
+ ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
"ParallelOp", "WhileOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index d3d07eec8ebff57..75d90b67bd82f36 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -33,6 +33,13 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
If one of the respective interface methods is implemented, so must the other
two. The interface verifier ensures that the number of types of the region
iter_args, init values and yielded values match.
+
+ Optionally, "loop results" can be exposed through this interface. These are
+ the values that are returned from the loop op when there are no more
+ iterations. The number and types of the loop results must match the
+ region iter_args. Note: Loop results are optional because some loops
+ (e.g., `scf.while`) may produce results that do not match 1-to-1 with the
+ region iter_args.
}];
let cppNamespace = "::mlir";
@@ -166,6 +173,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return {};
}]
>,
+ InterfaceMethod<[{
+ Return the range of results that are returned from this loop and
+ correspond to the "init" operands.
+
+ Note: This interface method is optional. If loop results are not
+ exposed via this interface, "std::nullopt" should be returned.
+ Otherwise, the number and types of results must match the
+ region iter_args, inits and yielded values that are exposed via this
+ interface. If loop results are exposed but this loop op has no
+ loop-carried variables, an empty result range (and not "std::nullopt")
+ should be returned.
+ }],
+ /*retTy=*/"::std::optional<::mlir::ResultRange>",
+ /*methodName=*/"getLoopResults",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::std::nullopt;
+ }]
+ >,
InterfaceMethod<[{
Append the specified additional "init" operands: replace this loop with
a new loop that has the additional init operands. The loop body of
@@ -242,6 +269,8 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
}
/// Return the region iter_arg that corresponds to the given init operand.
+ /// Return an "empty" block argument if the given operand is not an init
+ /// operand of this loop op.
BlockArgument getTiedLoopRegionIterArg(OpOperand *opOperand) {
auto initsMutable = $_op.getInitsMutable();
auto it = llvm::find(initsMutable, *opOperand);
@@ -250,7 +279,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return $_op.getRegionIterArgs()[std::distance(initsMutable.begin(), it)];
}
+ /// Return the region iter_arg that corresponds to the given loop result.
+ /// Return an "empty" block argument if the given OpResult is not a loop
+ /// result or if this op does not expose any loop results.
+ BlockArgument getTiedLoopRegionIterArg(OpResult opResult) {
+ auto loopResults = $_op.getLoopResults();
+ if (!loopResults)
+ return {};
+ auto it = llvm::find(*loopResults, opResult);
+ if (it == loopResults->end())
+ return {};
+ return $_op.getRegionIterArgs()[std::distance(loopResults->begin(), it)];
+ }
+
/// Return the init operand that corresponds to the given region iter_arg.
+ /// Return "nullptr" if the given block argument is not a region iter_arg
+ /// of this loop op.
OpOperand *getTiedLoopInit(BlockArgument bbArg) {
auto iterArgs = $_op.getRegionIterArgs();
auto it = llvm::find(iterArgs, bbArg);
@@ -259,7 +303,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return &$_op.getInitsMutable()[std::distance(iterArgs.begin(), it)];
}
+ /// Return the init operand that corresponds to the given loop result.
+ /// Return "nullptr" if the given OpResult is not a loop result or if this
+ /// op does not expose any loop results.
+ OpOperand *getTiedLoopInit(OpResult opResult) {
+ auto loopResults = $_op.getLoopResults();
+ if (!loopResults)
+ return nullptr;
+ auto it = llvm::find(*loopResults, opResult);
+ if (it == loopResults->end())
+ return nullptr;
+ return &$_op.getInitsMutable()[std::distance(loopResults->begin(), it)];
+ }
+
/// Return the yielded value that corresponds to the given region iter_arg.
+ /// Return "nullptr" if the given block argument is not a region iter_arg
+ /// of this loop op.
OpOperand *getTiedLoopYieldedValue(BlockArgument bbArg) {
auto iterArgs = $_op.getRegionIterArgs();
auto it = llvm::find(iterArgs, bbArg);
@@ -268,6 +327,34 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return
&$_op.getYieldedValuesMutable()[std::distance(iterArgs.begin(), it)];
}
+
+ /// Return the loop result that corresponds to the given init operand.
+ /// Return an "empty" OpResult if the given operand is not an init operand
+ /// of this loop op or if this op does not expose any loop results.
+ OpResult getTiedLoopResult(OpOperand *opOperand) {
+ auto loopResults = $_op.getLoopResults();
+ if (!loopResults)
+ return {};
+ auto initsMutable = $_op.getInitsMutable();
+ auto it = llvm::find(initsMutable, *opOperand);
+ if (it == initsMutable.end())
+ return {};
+ return (*loopResults)[std::distance(initsMutable.begin(), it)];
+ }
+
+ /// Return the loop result that corresponds to the given region iter_arg.
+ /// Return an "empty" OpResult if the given block argument is not a region
+ /// iter_arg of this loop op or if this op does not expose any loop results.
+ OpResult getTiedLoopResult(BlockArgument bbArg) {
+ auto loopResults = $_op.getLoopResults();
+ if (!loopResults)
+ return {};
+ auto iterArgs = $_op.getRegionIterArgs();
+ auto it = llvm::find(iterArgs, bbArg);
+ if (it == iterArgs.end())
+ return {};
+ return (*loopResults)[std::distance(iterArgs.begin(), it)];
+ }
}];
let verifyWithRegions = 1;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 19f704f5232ed81..866f51b0e92bbde 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -810,7 +810,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
- unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
+ unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
.getDefiningOp<tensor::ExtractSliceOp>();
if (!yieldingExtractSliceOp)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b8b75f3f476a5da..bc33fe2a9a01079 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -390,6 +390,8 @@ std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
return OpFoldResult(getUpperBound());
}
+std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
+
/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 885e00b48ff8434..dc3c46bf896a9cf 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -614,7 +614,7 @@ struct ForOpInterface
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto forOp = cast<scf::ForOp>(op);
- OpResult opResult = forOp.getResultForOpOperand(opOperand);
+ OpResult opResult = forOp.getTiedLoopResult(&opOperand);
BufferRelation relation = bufferRelation(op, opResult, state);
return {{opResult, relation,
/*isDefinite=*/relation == BufferRelation::Equivalent}};
@@ -625,10 +625,9 @@ struct ForOpInterface
// ForOp results are equivalent to their corresponding init_args if the
// corresponding iter_args and yield values are equivalent.
auto forOp = cast<scf::ForOp>(op);
- OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
- auto bbArg = forOp.getTiedLoopRegionIterArg(&forOperand);
+ BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
bool equivalentYield = state.areEquivalentBufferizedValues(
- bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
+ bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
return equivalentYield ? BufferRelation::Equivalent
: BufferRelation::Unknown;
}
@@ -703,16 +702,13 @@ struct ForOpInterface
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
- BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(
- &forOp.getOpOperandForResult(opResult));
+ BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
return bufferization::getBufferType(bbArg, options, invocationStack);
}
// Compute result/argument number.
BlockArgument bbArg = cast<BlockArgument>(value);
- unsigned resultNum =
- forOp.getResultForOpOperand(*forOp.getTiedLoopInit(bbArg))
- .getResultNumber();
+ unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
// Compute the bufferized type.
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e649125a09fea6a..df162d29a48eb89 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -609,8 +609,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
if (destinationInitArg &&
(*destinationInitArg)->getOwner() == outerMostLoop) {
unsigned iterArgNumber =
- outerMostLoop.getResultForOpOperand(**destinationInitArg)
- .getResultNumber();
+ outerMostLoop.getTiedLoopResult(*destinationInitArg).getResultNumber();
int64_t resultNumber = fusableProducer.getResultNumber();
if (auto dstOp =
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index 15a816f4e448839..be1316b95688bf2 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -58,7 +58,7 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
// but the LoopLikeOpInterface provides better error messages.
auto loopLikeOp = cast<LoopLikeOpInterface>(op);
- // Verify number of inits/iter_args/yielded values.
+ // Verify number of inits/iter_args/yielded values/loop results.
if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
return op->emitOpError("different number of inits and region iter_args: ")
<< loopLikeOp.getInits().size()
@@ -69,21 +69,43 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
"different number of region iter_args and yielded values: ")
<< loopLikeOp.getRegionIterArgs().size()
<< " != " << loopLikeOp.getYieldedValues().size();
+ if (loopLikeOp.getLoopResults() && loopLikeOp.getLoopResults()->size() !=
+ loopLikeOp.getRegionIterArgs().size())
+ return op->emitOpError(
+ "different number of loop results and region iter_args: ")
+ << loopLikeOp.getLoopResults()->size()
+ << " != " << loopLikeOp.getRegionIterArgs().size();
- // Verify types of inits/iter_args/yielded values.
+ // Verify types of inits/iter_args/yielded values/loop results.
int64_t i = 0;
for (const auto it :
llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
loopLikeOp.getYieldedValues())) {
if (std::get<0>(it).getType() != std::get<1>(it).getType())
- op->emitOpError(std::to_string(i))
- << "-th init and " << i << "-th region iter_arg have different type: "
- << std::get<0>(it).getType() << " != " << std::get<1>(it).getType();
+ return op->emitOpError(std::to_string(i))
+ << "-th init and " << i
+ << "-th region iter_arg have different type: "
+ << std::get<0>(it).getType()
+ << " != " << std::get<1>(it).getType();
if (std::get<1>(it).getType() != std::get<2>(it).getType())
- op->emitOpError(std::to_string(i))
- << "-th region iter_arg and " << i
- << "-th yielded value have different type: "
- << std::get<1>(it).getType() << " != " << std::get<2>(it).getType();
+ return op->emitOpError(std::to_string(i))
+ << "-th region iter_arg and " << i
+ << "-th yielded value have different type: "
+ << std::get<1>(it).getType()
+ << " != " << std::get<2>(it).getType();
+ ++i;
+ }
+ i = 0;
+ if (loopLikeOp.getLoopResults()) {
+ for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(),
+ *loopLikeOp.getLoopResults())) {
+ if (std::get<0>(it).getType() != std::get<1>(it).getType())
+ return op->emitOpError(std::to_string(i))
+ << "-th region iter_arg and " << i
+ << "-th loop result have different type: "
+ << std::get<0>(it).getType()
+ << " != " << std::get<1>(it).getType();
+ ++i;
+ }
- ++i;
}
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 1b2c3f563195c52..ad07a8b11327deb 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -96,6 +96,19 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {
// -----
+func.func @scf_for_incorrect_result_type(%arg0: index, %init: f32) {
+ // expected-error @below{{0-th region iter_arg and 0-th loop result have different type: 'f32' != 'f64'}}
+ "scf.for"(%arg0, %arg0, %arg0, %init) (
+ {
+ ^bb0(%i0 : index, %iter: f32):
+ scf.yield %iter : f32
+ }
+ ) : (index, index, index, f32) -> (f64)
+ return
+}
+
+// -----
+
func.func @too_many_iter_args(%arg0: index, %init: f32) {
// expected-error @below{{different number of inits and region iter_args: 1 != 2}}
%x = "scf.for"(%arg0, %arg0, %arg0, %init) (
@@ -449,7 +462,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
%s0 = arith.constant 0.0 : f32
%t0 = arith.constant 1.0 : f32
// expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
- // expected-error @below {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
%result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
%sn = arith.addf %si, %si : f32
More information about the Mlir-commits mailing list