[Mlir-commits] [mlir] de9caf2 - [mlir][Interfaces] Add `promoteIfSingleIteration` to `LoopLikeOpInterface`
Matthias Springer
llvmlistbot at llvm.org
Mon Jul 3 09:03:58 PDT 2023
Author: Matthias Springer
Date: 2023-07-03T18:03:35+02:00
New Revision: de9caf2f3244849cf0fe29bbf7cc67bae9035329
URL: https://github.com/llvm/llvm-project/commit/de9caf2f3244849cf0fe29bbf7cc67bae9035329
DIFF: https://github.com/llvm/llvm-project/commit/de9caf2f3244849cf0fe29bbf7cc67bae9035329.diff
LOG: [mlir][Interfaces] Add `promoteIfSingleIteration` to `LoopLikeOpInterface`
There are existing `promoteIfSingleIteration` implementations for `scf.for`, `scf.forall` and `affine.for`. This revision adds a corresponding interface method to the `LoopLikeOpInterface`.
* `scf.forall` now implements the `LoopLikeOpInterface`.
* The implementations of `scf.for` and `scf.forall` become interface method implementations. `affine.for` remains unchanged for the moment. (The implementation of `promoteIfSingleIteration` depends on helper functions from `MLIRAffineAnalysis`, which cannot be used from `MLIRAffineDialect`, where the interface is currently implemented.)
* More efficient implementations of `promoteIfSingleIteration`. In particular, the `scf.forall` operation now inlines operations instead of cloning them. This also preserves handles when used from the transform dialect.
Differential Revision: https://reviews.llvm.org/D154343
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCF.h
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/include/mlir/Interfaces/LoopLikeInterface.h
mlir/include/mlir/Interfaces/LoopLikeInterface.td
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index cb399b78c406d9..915ab3016b688e 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -62,13 +62,8 @@ ForallOp getForallOpThreadIndexOwner(Value val);
// TODO: Consider moving this functionality to RegionBranchOpInterface.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);
-/// Promotes the loop body of a scf::ForallOp to its containing block if the
-/// loop was known to have a single iteration.
-LogicalResult promoteIfSingleIteration(PatternRewriter &rewriter,
- scf::ForallOp forallOp);
-
/// Promotes the loop body of a scf::ForallOp to its containing block.
-void promote(PatternRewriter &rewriter, scf::ForallOp forallOp);
+void promote(RewriterBase &rewriter, scf::ForallOp forallOp);
/// An owning vector of values, handy to return from functions.
using ValueVector = SmallVector<Value>;
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 1eed4cec1e4207..58b720c6e39637 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -121,7 +121,7 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
- "getSingleUpperBound"]>,
+ "getSingleUpperBound", "promoteIfSingleIteration"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -361,6 +361,8 @@ def ForOp : SCF_Op<"for",
def ForallOp : SCF_Op<"forall", [
AttrSizedOperandSegments,
AutomaticAllocationScope,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
+ ["promoteIfSingleIteration"]>,
RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::InParallelOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 0b251a6fec9c18..2e299fd357f282 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -141,10 +141,6 @@ LogicalResult coalesceLoops(MutableArrayRef<scf::ForOp> loops);
void collapseParallelLoops(scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);
-/// Promotes the loop body of a scf::ForOp to its containing block if the loop
-/// was known to have a single iteration.
-LogicalResult promoteIfSingleIteration(scf::ForOp forOp);
-
/// Unrolls this for operation by the specified unroll factor. Returns failure
/// if the loop cannot be unrolled either due to restrictions or due to invalid
/// unroll factors. Requires positive loop bounds and step. If specified,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 48399ad0d53a84..9d81a61fac8856 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -15,6 +15,10 @@
#include "mlir/IR/OpDefinition.h"
+namespace mlir {
+class RewriterBase;
+} // namespace mlir
+
/// Include the generated interface declarations.
#include "mlir/Interfaces/LoopLikeInterface.h.inc"
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index df7315c4ecf045..a88d7e45e5d67f 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -48,6 +48,19 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
op->moveBefore($_op);
}]
>,
+ InterfaceMethod<[{
+ Promotes the loop body to its containing block if the loop is known to
+ have a single iteration. Returns "success" if the promotion was
+ successful.
+ }],
+ /*retTy=*/"::mlir::LogicalResult",
+ /*methodName=*/"promoteIfSingleIteration",
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::failure();
+ }]
+ >,
InterfaceMethod<[{
If there is a single induction variable return it, otherwise return
std::nullopt.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index db69195a7c7046..4f805d692637ea 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -385,6 +385,35 @@ std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
return OpFoldResult(getUpperBound());
}
+/// Promotes the loop body of a forOp to its containing block if the forOp
+/// it can be determined that the loop has a single iteration.
+LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
+ std::optional<int64_t> tripCount =
+ constantTripCount(getLowerBound(), getUpperBound(), getStep());
+ if (!tripCount.has_value() || tripCount != 1)
+ return failure();
+
+ // Replace all results with the yielded values.
+ auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
+ rewriter.replaceAllUsesWith(getResults(), yieldOp.getOperands());
+
+ // Replace block arguments with lower bound (replacement for IV) and
+ // iter_args.
+ SmallVector<Value> bbArgReplacements;
+ bbArgReplacements.push_back(getLowerBound());
+ bbArgReplacements.append(getIterOperands().begin(), getIterOperands().end());
+
+ // Move the loop body operations to the loop's containing block.
+ rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
+ getOperation()->getIterator(), bbArgReplacements);
+
+ // Erase the old terminator and the loop.
+ rewriter.eraseOp(yieldOp);
+ rewriter.eraseOp(*this);
+
+ return success();
+}
+
/// Prints the initialization list in the form of
/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
/// where 'inner' values are assumed to be region arguments and 'outer' values
@@ -536,59 +565,64 @@ void ForOp::getSuccessorRegions(std::optional<unsigned> index,
regions.push_back(RegionSuccessor(getResults()));
}
+Region &ForallOp::getLoopBody() { return getRegion(); }
+
/// Promotes the loop body of a forallOp to its containing block if it can be
/// determined that the loop has a single iteration.
-LogicalResult mlir::scf::promoteIfSingleIteration(PatternRewriter &rewriter,
- scf::ForallOp forallOp) {
+LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
for (auto [lb, ub, step] :
- llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
- forallOp.getMixedStep())) {
+ llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
auto tripCount = constantTripCount(lb, ub, step);
if (!tripCount.has_value() || *tripCount != 1)
return failure();
}
- promote(rewriter, forallOp);
+ promote(rewriter, *this);
return success();
}
/// Promotes the loop body of a scf::ForallOp to its containing block.
-void mlir::scf::promote(PatternRewriter &rewriter, scf::ForallOp forallOp) {
- IRMapping mapping;
- mapping.map(forallOp.getInductionVars(), forallOp.getLowerBound(rewriter));
- mapping.map(forallOp.getOutputBlockArguments(), forallOp.getOutputs());
- for (auto &bodyOp : forallOp.getBody()->without_terminator())
- rewriter.clone(bodyOp, mapping);
+void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
+ OpBuilder::InsertionGuard g(rewriter);
+ scf::InParallelOp terminator = forallOp.getTerminator();
+
+ // Replace block arguments with lower bounds (replacements for IVs) and
+ // outputs.
+ SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
+ bbArgReplacements.append(forallOp.getOutputs().begin(),
+ forallOp.getOutputs().end());
+ // Move the loop body operations to the loop's containing block.
+ rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
+ forallOp->getIterator(), bbArgReplacements);
+
+ // Replace the terminator with tensor.insert_slice ops.
+ rewriter.setInsertionPointAfter(forallOp);
SmallVector<Value> results;
results.reserve(forallOp.getResults().size());
- scf::InParallelOp terminator = forallOp.getTerminator();
for (auto &yieldingOp : terminator.getYieldingOps()) {
auto parallelInsertSliceOp =
cast<tensor::ParallelInsertSliceOp>(yieldingOp);
Value dst = parallelInsertSliceOp.getDest();
Value src = parallelInsertSliceOp.getSource();
-
- auto getMappedValues = [&](ValueRange values) {
- return llvm::to_vector(llvm::map_range(
- values, [&](Value value) { return mapping.lookupOrDefault(value); }));
- };
-
- Value srcVal = mapping.lookupOrDefault(src);
- if (llvm::isa<TensorType>(srcVal.getType())) {
+ if (llvm::isa<TensorType>(src.getType())) {
results.push_back(rewriter.create<tensor::InsertSliceOp>(
- forallOp.getLoc(), dst.getType(), srcVal,
- mapping.lookupOrDefault(dst),
- getMappedValues(parallelInsertSliceOp.getOffsets()),
- getMappedValues(parallelInsertSliceOp.getSizes()),
- getMappedValues(parallelInsertSliceOp.getStrides()),
+ forallOp.getLoc(), dst.getType(), src, dst,
+ parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
+ parallelInsertSliceOp.getStrides(),
parallelInsertSliceOp.getStaticOffsets(),
parallelInsertSliceOp.getStaticSizes(),
parallelInsertSliceOp.getStaticStrides()));
+ } else {
+ llvm_unreachable("unsupported terminator");
}
}
- rewriter.replaceOp(forallOp, results);
+ rewriter.replaceAllUsesWith(forallOp.getResults(), results);
+
+ // Erase the old terminator and the loop.
+ rewriter.eraseOp(terminator);
+ rewriter.eraseOp(forallOp);
}
LoopNest mlir::scf::buildLoopNest(
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 34225fd133b2e3..96e4f890a9c12d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -362,44 +362,6 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
return builder.create<arith::DivUIOp>(loc, sum, divisor);
}
-/// Helper to replace uses of loop carried values (iter_args) and loop
-/// yield values while promoting single iteration scf.for ops.
-static void replaceIterArgsAndYieldResults(scf::ForOp forOp) {
- // Replace uses of iter arguments with iter operands (initial values).
- auto iterOperands = forOp.getIterOperands();
- auto iterArgs = forOp.getRegionIterArgs();
- for (auto e : llvm::zip(iterOperands, iterArgs))
- std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
-
- // Replace uses of loop results with the values yielded by the loop.
- auto outerResults = forOp.getResults();
- auto innerResults = forOp.getBody()->getTerminator()->getOperands();
- for (auto e : llvm::zip(outerResults, innerResults))
- std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
-}
-
-/// Promotes the loop body of a forOp to its containing block if the forOp
-/// it can be determined that the loop has a single iteration.
-LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
- std::optional<int64_t> tripCount = constantTripCount(
- forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
- if (!tripCount.has_value() || tripCount != 1)
- return failure();
- auto iv = forOp.getInductionVar();
- iv.replaceAllUsesWith(forOp.getLowerBound());
-
- replaceIterArgsAndYieldResults(forOp);
-
- // Move the loop body operations, except for its terminator, to the loop's
- // containing block.
- auto *parentBlock = forOp->getBlock();
- forOp.getBody()->getTerminator()->erase();
- parentBlock->getOperations().splice(Block::iterator(forOp),
- forOp.getBody()->getOperations());
- forOp.erase();
- return success();
-}
-
/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -469,6 +431,7 @@ LogicalResult mlir::loopUnrollByFactor(
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
OpBuilder boundsBuilder(forOp);
+ IRRewriter rewriter(forOp.getContext());
auto loc = forOp.getLoc();
Value step = forOp.getStep();
Value upperBoundUnrolled;
@@ -488,7 +451,7 @@ LogicalResult mlir::loopUnrollByFactor(
int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
if (unrollFactor == 1) {
- if (tripCount == 1 && failed(promoteIfSingleIteration(forOp)))
+ if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
return failure();
return success();
}
@@ -553,7 +516,7 @@ LogicalResult mlir::loopUnrollByFactor(
}
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getNumIterOperands(), results);
- (void)promoteIfSingleIteration(epilogueForOp);
+ (void)epilogueForOp.promoteIfSingleIteration(rewriter);
}
// Create unrolled loop.
@@ -573,7 +536,7 @@ LogicalResult mlir::loopUnrollByFactor(
},
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
- (void)promoteIfSingleIteration(forOp);
+ (void)forOp.promoteIfSingleIteration(rewriter);
return success();
}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index e23e86184c3118..e2b45e85e4f8c7 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -545,9 +545,9 @@ func.func @parallel_insert_slice_no_conflict(
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
+ // CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1]
// CHECK: scf.forall (%[[tidx:.*]]) in (%[[idx2]])
%2 = scf.forall (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
- // CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1]
%6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview]] : memref<?xf32
%8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
@@ -591,9 +591,9 @@ func.func @parallel_insert_slice_with_conflict(
// CHECK: %[[alloc1:.*]] = memref.alloc
// CHECK: memref.copy %[[arg2]], %[[alloc1]]
+ // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
// CHECK: scf.forall (%[[tidx:.*]]) in (%[[idx2]])
%2 = scf.forall (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
- // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
%6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
// CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview1]] : memref<?xf32
More information about the Mlir-commits mailing list