[Mlir-commits] [mlir] [mlir][Interfaces] `LoopLikeOpInterface`: Add helpers to query tied inits/iter_args (PR #70408)
Matthias Springer
llvmlistbot at llvm.org
Thu Oct 26 20:49:14 PDT 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/70408
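The `LoopLikeOpInterface` already has interface methods to query inits and iter_args. This commit adds helper functions to query tied init/iter_arg pairs (`getTiedLoopRegionIterArg`, `getTiedLoopInit`, `getTiedLoopYieldedValue`) and removes the corresponding `scf::ForOp` functions (`getRegionIterArgForOpOperand`, `getOpOperandForRegionIterArg`), updating their callers. It also lists `getLoopResults` among the `LoopLikeOpInterface` methods declared by `scf::ForOp` and `scf::WhileOp`.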
>From ce1059130b5729c1f9269ac1facf0058c606de8a Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 27 Oct 2023 12:48:16 +0900
Subject: [PATCH] [mlir][Interfaces] `LoopLikeOpInterface`: Add helpers to
query tied inits/iter_args
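The `LoopLikeOpInterface` already has interface methods to query inits and iter_args. This commit adds helper functions to query tied init/iter_arg pairs (`getTiedLoopRegionIterArg`, `getTiedLoopInit`, `getTiedLoopYieldedValue`) and removes the corresponding `scf::ForOp` functions (`getRegionIterArgForOpOperand`, `getOpOperandForRegionIterArg`), updating their callers. It also lists `getLoopResults` among the `LoopLikeOpInterface` methods declared by `scf::ForOp` and `scf::WhileOp`.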
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 27 +++----------------
.../mlir/Interfaces/LoopLikeInterface.h | 21 +++++++++++++++
.../mlir/Interfaces/LoopLikeInterface.td | 24 ++++++++++++++++-
.../Linalg/Transforms/HoistPadding.cpp | 2 +-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 +--
.../BufferizableOpInterfaceImpl.cpp | 10 +++----
.../SCF/Transforms/LoopCanonicalization.cpp | 2 +-
.../Dialect/SCF/Transforms/LoopPipelining.cpp | 4 +--
.../SCF/Transforms/TileUsingInterface.cpp | 2 +-
9 files changed, 59 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index fde9176c670bc6b..43beebc1bf54166 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -123,7 +123,8 @@ def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getSingleInductionVar", "getSingleLowerBound",
"getSingleStep", "getSingleUpperBound", "getYieldedValuesMutable",
- "promoteIfSingleIteration", "replaceWithAdditionalYields"]>,
+ "getLoopResults", "promoteIfSingleIteration",
+ "replaceWithAdditionalYields"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
@@ -268,28 +269,6 @@ def ForOp : SCF_Op<"for",
/// Number of operands controlling the loop: lb, ub, step
unsigned getNumControlOperands() { return 3; }
- /// Get the region iter arg that corresponds to an OpOperand.
- /// This helper prevents internal op implementation detail leakage to
- /// clients by hiding the operand / block argument mapping.
- BlockArgument getRegionIterArgForOpOperand(OpOperand &opOperand) {
- assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
- "expected an iter args operand");
- assert(opOperand.getOwner() == getOperation() &&
- "opOperand does not belong to this scf::ForOp operation");
- return getRegionIterArgs()[
- opOperand.getOperandNumber() - getNumControlOperands()];
- }
- /// Get the OpOperand& that corresponds to a region iter arg.
- /// This helper prevents internal op implementation detail leakage to
- /// clients by hiding the operand / block argument mapping.
- OpOperand &getOpOperandForRegionIterArg(BlockArgument bbArg) {
- assert(bbArg.getArgNumber() >= getNumInductionVars() &&
- "expected a bbArg that is not an induction variable");
- assert(bbArg.getOwner()->getParentOp() == getOperation() &&
- "bbArg does not belong to the scf::ForOp body");
- return getOperation()->getOpOperand(
- getNumControlOperands() + bbArg.getArgNumber() - getNumInductionVars());
- }
/// Get the OpResult that corresponds to an OpOperand.
/// Assert that opOperand is an iterArg.
/// This helper prevents internal op implementation detail leakage to
@@ -963,7 +942,7 @@ def WhileOp : SCF_Op<"while",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
- ["getRegionIterArgs", "getYieldedValuesMutable"]>,
+ ["getLoopResults", "getRegionIterArgs", "getYieldedValuesMutable"]>,
RecursiveMemoryEffects, SingleBlock]> {
let summary = "a generic 'while' loop";
let description = [{
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 7c7d378d0590ab1..b5ac1e367b2a761 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -28,6 +28,27 @@ using NewYieldValuesFn = std::function<SmallVector<Value>(
namespace detail {
/// Verify invariants of the LoopLikeOpInterface.
LogicalResult verifyLoopLikeOpInterface(Operation *op);
+
+/// Return the index of the first element of `array` that compares equal to
+/// `val` (using `==`), or std::nullopt if no element of `array` does.
+template <typename Range, typename T>
+std::optional<int64_t> findIdx(const Range &array, const T &val) {
+ for (int64_t i = 0, e = array.size(); i < e; ++i)
+ if (array[i] == val)
+ return i;
+ return std::nullopt;
+}
+
+/// Return the index of the element of `array` that is the object `val` itself
+/// (compared by address, not by value), or std::nullopt if there is none.
+template <typename Range, typename T>
+std::optional<int64_t> findSamePtrIdx(const Range &array, const T &val) {
+ for (int64_t i = 0, e = array.size(); i < e; ++i)
+ if (&array[i] == &val)
+ return i;
+ return std::nullopt;
+}
+
} // namespace detail
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index afb7860491664da..76884c70d82a52a 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -238,7 +238,29 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
unsigned firstOperandIndex = initsMutable.begin()->getOperandNumber();
return OperandRange(
$_op->operand_begin() + firstOperandIndex,
- $_op->operand_begin() + firstOperandIndex + initsMutable.size()); }
+ $_op->operand_begin() + firstOperandIndex + initsMutable.size());
+ }
+
+ /// Return the region iter_arg tied to init `opOperand`, or null if it is not an init.
+ BlockArgument getTiedLoopRegionIterArg(OpOperand *opOperand) {
+ if (auto idx = detail::findSamePtrIdx($_op.getInitsMutable(), *opOperand))
+ return $_op.getRegionIterArgs()[*idx];
+ return {};
+ }
+
+ /// Return the init operand tied to region iter_arg `bbArg`, or nullptr if none.
+ OpOperand *getTiedLoopInit(BlockArgument bbArg) {
+ if (auto idx = detail::findIdx($_op.getRegionIterArgs(), bbArg))
+ return &$_op.getInitsMutable()[*idx];
+ return nullptr;
+ }
+
+ /// Return the yielded value tied to region iter_arg `bbArg`, or nullptr if none.
+ OpOperand *getTiedLoopYieldedValue(BlockArgument bbArg) {
+ if (auto idx = detail::findIdx($_op.getRegionIterArgs(), bbArg))
+ return &$_op.getYieldedValuesMutable()[*idx];
+ return nullptr;
+ }
}];
let verifyWithRegions = 1;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 8fef99bb375095a..19f704f5232ed81 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -558,7 +558,7 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
break;
if (forOp != outerLoop && !outerLoop->isAncestor(forOp))
break;
- OpOperand &operand = forOp.getOpOperandForRegionIterArg(bbArg);
+ OpOperand &operand = *forOp.getTiedLoopInit(bbArg);
bvm.map(bbArg, operand.get());
bbArg = dyn_cast<BlockArgument>(operand.get());
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index cb888bc17c571fe..b8b75f3f476a5da 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -994,8 +994,8 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
// corresponding to the `replacement` value.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(&newBlock, newBlock.begin());
- BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
- newForOp->getOpOperand(operand.getOperandNumber()));
+ BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
+ &newForOp->getOpOperand(operand.getOperandNumber()));
Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
newRegionIterArg);
newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 455b7d8bcaff0e2..885e00b48ff8434 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -602,7 +602,7 @@ struct ForOpInterface
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
// its matching bbArg may.
- return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
+ return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -626,7 +626,7 @@ struct ForOpInterface
// corresponding iter_args and yield values are equivalent.
auto forOp = cast<scf::ForOp>(op);
OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
- auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+ auto bbArg = forOp.getTiedLoopRegionIterArg(&forOperand);
bool equivalentYield = state.areEquivalentBufferizedValues(
bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
return equivalentYield ? BufferRelation::Equivalent
@@ -703,15 +703,15 @@ struct ForOpInterface
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
- BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(
- forOp.getOpOperandForResult(opResult));
+ BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(
+ &forOp.getOpOperandForResult(opResult));
return bufferization::getBufferType(bbArg, options, invocationStack);
}
// Compute result/argument number.
BlockArgument bbArg = cast<BlockArgument>(value);
unsigned resultNum =
- forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
+ forOp.getResultForOpOperand(*forOp.getTiedLoopInit(bbArg))
.getResultNumber();
// Compute the bufferized type.
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 43e79d309c66784..eee0791b397ae68 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -98,7 +98,7 @@ struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
return failure();
- Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
+ Value initArg = forOp.getTiedLoopInit(blockArg)->get();
rewriter.updateRootInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 4ff7965c0858a3b..5537a8b212c51f7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -194,8 +194,8 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
// Initialize the iteration argument to the loop initiale values.
- for (BlockArgument &arg : forOp.getRegionIterArgs()) {
- OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
+ for (auto [arg, operand] :
+ llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
setValueMapping(arg, operand.get(), 0);
}
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2bddb498c21ac33..e649125a09fea6a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -523,7 +523,7 @@ getUntiledProducerFromSliceSource(OpOperand *source,
scf::ForOp loop = *loopIt;
if (iterArg.getOwner()->getParentOp() != loop)
break;
- source = &loop.getOpOperandForRegionIterArg(iterArg);
+ source = loop.getTiedLoopInit(iterArg);
loopIt++;
}
if (loopIt == loops.rend())
More information about the Mlir-commits mailing list