[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jun 8 21:33:41 PDT 2024
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/94391
>From 5020e498440b0016adef7e99806aa55c4837b441 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 13:08:22 -0500
Subject: [PATCH 01/13] Add getters for multi dim loop variables in
LoopLikeOpInterface
---
.../mlir/Dialect/Affine/IR/AffineOps.td | 4 +-
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 37 ++--------
.../mlir/Interfaces/LoopLikeInterface.td | 65 +++++++++++------
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 20 +++---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 70 +++++++------------
.../Dialect/SCF/LoopLikeSCFOpsTest.cpp | 8 +++
6 files changed, 97 insertions(+), 107 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 3640055ea8da8..bb2c29b5733b8 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -118,8 +118,8 @@ def AffineForOp : Affine_Op<"for",
[AttrSizedOperandSegments, AutomaticAllocationScope,
ImplicitAffineTerminator, ConditionallySpeculatable,
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
- ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
- "getSingleUpperBound", "getYieldedValuesMutable",
+ ["getInductionVars", "getMixedLowerBound", "getMixedStep",
+ "getMixedUpperBound", "getYieldedValuesMutable",
"replaceWithAdditionalYields"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0b063aa772bab..3b28ca8b21d0f 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
- "getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
- "getSingleUpperBound", "getYieldedValuesMutable",
+ "getInductionVars", "getMixedLowerBound", "getMixedStep",
+ "getMixedUpperBound", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
@@ -301,8 +301,8 @@ def ForallOp : SCF_Op<"forall", [
AttrSizedOperandSegments,
AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
- ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar",
- "getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
+ ["getInitsMutable", "getRegionIterArgs", "getInductionVars",
+ "getMixedLowerBound", "getMixedUpperBound", "getMixedStep",
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::InParallelOp">,
@@ -510,24 +510,6 @@ def ForallOp : SCF_Op<"forall", [
];
let extraClassDeclaration = [{
- // Get lower bounds as OpFoldResult.
- SmallVector<OpFoldResult> getMixedLowerBound() {
- Builder b(getOperation()->getContext());
- return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
- }
-
- // Get upper bounds as OpFoldResult.
- SmallVector<OpFoldResult> getMixedUpperBound() {
- Builder b(getOperation()->getContext());
- return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
- }
-
- // Get steps as OpFoldResult.
- SmallVector<OpFoldResult> getMixedStep() {
- Builder b(getOperation()->getContext());
- return getMixedValues(getStaticStep(), getDynamicStep(), b);
- }
-
/// Get lower bounds as values.
SmallVector<Value> getLowerBound(OpBuilder &b) {
return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedLowerBound());
@@ -584,10 +566,6 @@ def ForallOp : SCF_Op<"forall", [
getNumDynamicControlOperands() + getRank());
}
- ::mlir::ValueRange getInductionVars() {
- return getBody()->getArguments().take_front(getRank());
- }
-
::mlir::Value getInductionVar(int64_t idx) {
return getInductionVars()[idx];
}
@@ -765,8 +743,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
def ParallelOp : SCF_Op<"parallel",
[AutomaticAllocationScope,
AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getSingleInductionVar",
- "getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getInductionVars",
+ "getMixedLowerBound", "getMixedUpperBound", "getMixedStep"]>,
RecursiveMemoryEffects,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"scf::ReduceOp">,
@@ -846,9 +824,6 @@ def ParallelOp : SCF_Op<"parallel",
];
let extraClassDeclaration = [{
- ValueRange getInductionVars() {
- return getBody()->getArguments();
- }
unsigned getNumLoops() { return getStep().size(); }
unsigned getNumReductions() { return getInitVals().size(); }
}];
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index f0dc6e60eba58..813779c852027 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -93,51 +93,47 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
}]
>,
InterfaceMethod<[{
- If there is a single induction variable return it, otherwise return
- std::nullopt.
+ Return all induction variables.
}],
- /*retTy=*/"::std::optional<::mlir::Value>",
- /*methodName=*/"getSingleInductionVar",
+ /*retTy=*/"::mlir::ValueRange",
+ /*methodName=*/"getInductionVars",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return std::nullopt;
+ return {};
}]
>,
InterfaceMethod<[{
- Return the single lower bound value or attribute if it exists, otherwise
- return std::nullopt.
+ Return all lower bounds.
}],
- /*retTy=*/"::std::optional<::mlir::OpFoldResult>",
- /*methodName=*/"getSingleLowerBound",
+ /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
+ /*methodName=*/"getMixedLowerBound",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return std::nullopt;
+ return {};
}]
>,
InterfaceMethod<[{
- Return the single step value or attribute if it exists, otherwise
- return std::nullopt.
+ Return all steps.
}],
- /*retTy=*/"::std::optional<::mlir::OpFoldResult>",
- /*methodName=*/"getSingleStep",
+ /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
+ /*methodName=*/"getMixedStep",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return std::nullopt;
+ return {};
}]
>,
InterfaceMethod<[{
- Return the single upper bound value or attribute if it exists, otherwise
- return std::nullopt.
+ Return all upper bounds.
}],
- /*retTy=*/"::std::optional<::mlir::OpFoldResult>",
- /*methodName=*/"getSingleUpperBound",
+ /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
+ /*methodName=*/"getMixedUpperBound",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return std::nullopt;
+ return {};
}]
>,
InterfaceMethod<[{
@@ -235,6 +231,35 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
}];
let extraSharedClassDeclaration = [{
+  /// If there is a single induction variable return it, otherwise return
+  /// std::nullopt.
+  ::std::optional<::mlir::Value> getSingleInductionVar() {
+    if (auto ivs = this->getInductionVars(); ivs.size() == 1)
+      return ivs[0];
+    return std::nullopt;
+  }
+  /// Return the single lower bound value or attribute if it exists, otherwise
+  /// return std::nullopt. Calls `getMixedLowerBound()` once; it may allocate.
+  ::std::optional<::mlir::OpFoldResult> getSingleLowerBound() {
+    if (auto lbs = this->getMixedLowerBound(); lbs.size() == 1)
+      return lbs[0];
+    return std::nullopt;
+  }
+  /// Return the single step value or attribute if it exists, otherwise
+  /// return std::nullopt. Calls `getMixedStep()` once; it may allocate.
+  ::std::optional<::mlir::OpFoldResult> getSingleStep() {
+    if (auto steps = this->getMixedStep(); steps.size() == 1)
+      return steps[0];
+    return std::nullopt;
+  }
+  /// Return the single upper bound value or attribute if it exists, otherwise
+  /// return std::nullopt. Calls `getMixedUpperBound()` once; it may allocate.
+  ::std::optional<::mlir::OpFoldResult> getSingleUpperBound() {
+    if (auto ubs = this->getMixedUpperBound(); ubs.size() == 1)
+      return ubs[0];
+    return std::nullopt;
+  }
+
/// Append the specified additional "init" operands: replace this loop with
/// a new loop that has the additional init operands. The loop body of this
/// loop is moved over to the new loop.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 2e31487bd55a0..746a9c919560c 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2454,27 +2454,27 @@ bool AffineForOp::matchingBoundOperandList() {
SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
-std::optional<Value> AffineForOp::getSingleInductionVar() {
-  return getInductionVar();
-}
+ValueRange AffineForOp::getInductionVars() {
+  return getBody()->getArguments().take_front();
+}
-std::optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
+SmallVector<OpFoldResult> AffineForOp::getMixedLowerBound() {
if (!hasConstantLowerBound())
- return std::nullopt;
+ return {};
OpBuilder b(getContext());
- return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()));
+ return {OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
}
-std::optional<OpFoldResult> AffineForOp::getSingleStep() {
+SmallVector<OpFoldResult> AffineForOp::getMixedStep() {
OpBuilder b(getContext());
- return OpFoldResult(b.getI64IntegerAttr(getStepAsInt()));
+ return {OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
}
-std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
+SmallVector<OpFoldResult> AffineForOp::getMixedUpperBound() {
if (!hasConstantUpperBound())
- return std::nullopt;
+ return {};
OpBuilder b(getContext());
- return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
+ return {OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
}
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 107fd0690f193..e275ff1849c10 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -378,20 +378,20 @@ LogicalResult ForOp::verifyRegions() {
return success();
}
-std::optional<Value> ForOp::getSingleInductionVar() {
-  return getInductionVar();
-}
+ValueRange ForOp::getInductionVars() {
+  return getBody()->getArguments().take_front();
+}
-std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
- return OpFoldResult(getLowerBound());
+SmallVector<OpFoldResult> ForOp::getMixedLowerBound() {
+ return {OpFoldResult(getLowerBound())};
}
-std::optional<OpFoldResult> ForOp::getSingleStep() {
- return OpFoldResult(getStep());
+SmallVector<OpFoldResult> ForOp::getMixedStep() {
+ return {OpFoldResult(getStep())};
}
-std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
- return OpFoldResult(getUpperBound());
+SmallVector<OpFoldResult> ForOp::getMixedUpperBound() {
+ return {OpFoldResult(getUpperBound())};
}
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
@@ -1428,28 +1426,26 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
return storeOps;
}
-std::optional<Value> ForallOp::getSingleInductionVar() {
- if (getRank() != 1)
- return std::nullopt;
- return getInductionVar(0);
+ValueRange ForallOp::getInductionVars() {
+ return getBody()->getArguments().take_front(getRank());
}
-std::optional<OpFoldResult> ForallOp::getSingleLowerBound() {
- if (getRank() != 1)
- return std::nullopt;
- return getMixedLowerBound()[0];
+// Get lower bounds as OpFoldResult.
+SmallVector<OpFoldResult> ForallOp::getMixedLowerBound() {
+ Builder b(getOperation()->getContext());
+ return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
}
-std::optional<OpFoldResult> ForallOp::getSingleUpperBound() {
- if (getRank() != 1)
- return std::nullopt;
- return getMixedUpperBound()[0];
+// Get upper bounds as OpFoldResult.
+SmallVector<OpFoldResult> ForallOp::getMixedUpperBound() {
+ Builder b(getOperation()->getContext());
+ return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
}
-std::optional<OpFoldResult> ForallOp::getSingleStep() {
- if (getRank() != 1)
- return std::nullopt;
- return getMixedStep()[0];
+// Get steps as OpFoldResult.
+SmallVector<OpFoldResult> ForallOp::getMixedStep() {
+ Builder b(getOperation()->getContext());
+ return getMixedValues(getStaticStep(), getDynamicStep(), b);
}
ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
@@ -3008,29 +3004,17 @@ void ParallelOp::print(OpAsmPrinter &p) {
SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
-std::optional<Value> ParallelOp::getSingleInductionVar() {
- if (getNumLoops() != 1)
- return std::nullopt;
- return getBody()->getArgument(0);
-}
+ValueRange ParallelOp::getInductionVars() { return getBody()->getArguments(); }
-std::optional<OpFoldResult> ParallelOp::getSingleLowerBound() {
- if (getNumLoops() != 1)
- return std::nullopt;
- return getLowerBound()[0];
+SmallVector<OpFoldResult> ParallelOp::getMixedLowerBound() {
+ return getLowerBound();
}
-std::optional<OpFoldResult> ParallelOp::getSingleUpperBound() {
- if (getNumLoops() != 1)
- return std::nullopt;
- return getUpperBound()[0];
+SmallVector<OpFoldResult> ParallelOp::getMixedUpperBound() {
+ return getUpperBound();
}
-std::optional<OpFoldResult> ParallelOp::getSingleStep() {
- if (getNumLoops() != 1)
- return std::nullopt;
- return getStep()[0];
-}
+SmallVector<OpFoldResult> ParallelOp::getMixedStep() { return getStep(); }
ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 6bc0fd6113b9b..d8cdb213070da 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -36,6 +36,10 @@ class SCFLoopLikeTest : public ::testing::Test {
std::optional<OpFoldResult> maybeIndVar =
loopLikeOp.getSingleInductionVar();
EXPECT_TRUE(maybeIndVar.has_value());
+ EXPECT_EQ(loopLikeOp.getInductionVars().size(), 1u);
+ EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u);
+ EXPECT_EQ(loopLikeOp.getMixedStep().size(), 1u);
+    EXPECT_EQ(loopLikeOp.getMixedUpperBound().size(), 1u);
}
void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -48,6 +52,10 @@ class SCFLoopLikeTest : public ::testing::Test {
std::optional<OpFoldResult> maybeIndVar =
loopLikeOp.getSingleInductionVar();
EXPECT_FALSE(maybeIndVar.has_value());
+ EXPECT_EQ(loopLikeOp.getInductionVars().size(), 2u);
+ EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u);
+ EXPECT_EQ(loopLikeOp.getMixedStep().size(), 2u);
+    EXPECT_EQ(loopLikeOp.getMixedUpperBound().size(), 2u);
}
MLIRContext context;
>From 50852d570440e0041c8b2b38925c4af05fac0636 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 4 Jun 2024 14:44:53 -0500
Subject: [PATCH 02/13] Refactor LoopFuseSiblingOp and support parallel fusion
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 16 ++
.../SCF/TransformOps/SCFTransformOps.cpp | 53 +++--
.../SCF/Transforms/ParallelLoopFusion.cpp | 204 +----------------
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 208 ++++++++++++++++++
.../SCF/transform-loop-fuse-sibling.mlir | 53 +++++
5 files changed, 304 insertions(+), 230 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bc09cc7f7fa5e..2944d8ffac022 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -156,6 +156,12 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
scf::ForOp root);
+/// Prepends operations of firstPloop's body into secondPloop's body.
+/// Updates secondPloop with new loop.
+void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
+ OpBuilder builder,
+ llvm::function_ref<bool(Value, Value)> mayAlias);
+
/// Given two scf.forall loops, `target` and `source`, fuses `target` into
/// `source`. Assumes that the given loops are siblings and are independent of
/// each other.
@@ -177,6 +183,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
RewriterBase &rewriter);
+/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
+/// `source`. Assumes that the given loops are siblings and are independent of
+/// each other.
+///
+/// This function does not perform any legality checks and simply fuses the
+/// loops. The caller is responsible for ensuring that the loops are legal to
+/// fuse.
+scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
+ scf::ParallelOp source,
+ RewriterBase &rewriter);
} // namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..1c53e89d69040 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -442,39 +442,32 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
return DiagnosedSilenceableFailure::success();
}
-/// Check if `target` scf.forall can be fused into `source` scf.forall.
+/// Check if `target` scf loop can be fused into `source` scf loop.
+/// Applies for scf.for, scf.forall, and scf.parallel.
///
/// This simply checks if both loops have the same bounds, steps and mapping.
/// No attempt is made at checking that the side effects of `target` and
/// `source` are independent of each other.
-static bool isForallWithIdenticalConfiguration(Operation *target,
- Operation *source) {
- auto targetOp = dyn_cast<scf::ForallOp>(target);
- auto sourceOp = dyn_cast<scf::ForallOp>(source);
- if (!targetOp || !sourceOp)
- return false;
-
- return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
- targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
- targetOp.getMixedStep() == sourceOp.getMixedStep() &&
- targetOp.getMapping() == sourceOp.getMapping();
-}
-
-/// Check if `target` scf.for can be fused into `source` scf.for.
-///
-/// This simply checks if both loops have the same bounds and steps. No attempt
-/// is made at checking that the side effects of `target` and `source` are
-/// independent of each other.
-static bool isForWithIdenticalConfiguration(Operation *target,
- Operation *source) {
- auto targetOp = dyn_cast<scf::ForOp>(target);
- auto sourceOp = dyn_cast<scf::ForOp>(source);
+template <typename LoopTy>
+static bool isLoopWithIdenticalConfiguration(Operation *target,
+ Operation *source) {
+ static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
+ scf::ParallelOp>::value,
+ "applies to only `forall`, `for` and `parallel`");
+ auto targetOp = dyn_cast<LoopTy>(target);
+ auto sourceOp = dyn_cast<LoopTy>(source);
if (!targetOp || !sourceOp)
return false;
- return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
- targetOp.getUpperBound() == sourceOp.getUpperBound() &&
- targetOp.getStep() == sourceOp.getStep();
+ if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
+ return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
+ targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
+ targetOp.getMixedStep() == sourceOp.getMixedStep() &&
+ targetOp.getMapping() == sourceOp.getMapping();
+ else
+ return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+ targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+ targetOp.getStep() == sourceOp.getStep();
}
DiagnosedSilenceableFailure
@@ -502,12 +495,16 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
Operation *fusedLoop;
-  /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
+  /// TODO: Support fusion for other loop-like ops (e.g. scf.while).
- if (isForWithIdenticalConfiguration(target, source)) {
+ if (isLoopWithIdenticalConfiguration<scf::ForOp>(target, source)) {
fusedLoop = fuseIndependentSiblingForLoops(
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
- } else if (isForallWithIdenticalConfiguration(target, source)) {
+ } else if (isLoopWithIdenticalConfiguration<scf::ForallOp>(target, source)) {
fusedLoop = fuseIndependentSiblingForallLoops(
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
+ } else if (isLoopWithIdenticalConfiguration<scf::ParallelOp>(target,
+ source)) {
+ fusedLoop = fuseIndependentSiblingParallelLoops(
+ cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
} else
return emitSilenceableFailure(target->getLoc())
<< "operations cannot be fused";
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 5934d85373b03..abac91cfaf7d9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
@@ -30,207 +31,6 @@ namespace mlir {
using namespace mlir;
using namespace mlir::scf;
-/// Verify there are no nested ParallelOps.
-static bool hasNestedParallelOp(ParallelOp ploop) {
- auto walkResult =
- ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
- return walkResult.wasInterrupted();
-}
-
-/// Verify equal iteration spaces.
-static bool equalIterationSpaces(ParallelOp firstPloop,
- ParallelOp secondPloop) {
- if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
- return false;
-
- auto matchOperands = [&](const OperandRange &lhs,
- const OperandRange &rhs) -> bool {
- // TODO: Extend this to support aliases and equal constants.
- return std::equal(lhs.begin(), lhs.end(), rhs.begin());
- };
- return matchOperands(firstPloop.getLowerBound(),
- secondPloop.getLowerBound()) &&
- matchOperands(firstPloop.getUpperBound(),
- secondPloop.getUpperBound()) &&
- matchOperands(firstPloop.getStep(), secondPloop.getStep());
-}
-
-/// Checks if the parallel loops have mixed access to the same buffers. Returns
-/// `true` if the first parallel loop writes to the same indices that the second
-/// loop reads.
-static bool haveNoReadsAfterWriteExceptSameIndex(
- ParallelOp firstPloop, ParallelOp secondPloop,
- const IRMapping &firstToSecondPloopIndices,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
- DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
- SmallVector<Value> bufferStoresVec;
- firstPloop.getBody()->walk([&](memref::StoreOp store) {
- bufferStores[store.getMemRef()].push_back(store.getIndices());
- bufferStoresVec.emplace_back(store.getMemRef());
- });
- auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
- Value loadMem = load.getMemRef();
- // Stop if the memref is defined in secondPloop body. Careful alias analysis
- // is needed.
- auto *memrefDef = loadMem.getDefiningOp();
- if (memrefDef && memrefDef->getBlock() == load->getBlock())
- return WalkResult::interrupt();
-
- for (Value store : bufferStoresVec)
- if (store != loadMem && mayAlias(store, loadMem))
- return WalkResult::interrupt();
-
- auto write = bufferStores.find(loadMem);
- if (write == bufferStores.end())
- return WalkResult::advance();
-
- // Check that at last one store was retrieved
- if (!write->second.size())
- return WalkResult::interrupt();
-
- auto storeIndices = write->second.front();
-
- // Multiple writes to the same memref are allowed only on the same indices
- for (const auto &othStoreIndices : write->second) {
- if (othStoreIndices != storeIndices)
- return WalkResult::interrupt();
- }
-
- // Check that the load indices of secondPloop coincide with store indices of
- // firstPloop for the same memrefs.
- auto loadIndices = load.getIndices();
- if (storeIndices.size() != loadIndices.size())
- return WalkResult::interrupt();
- for (int i = 0, e = storeIndices.size(); i < e; ++i) {
- if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
- loadIndices[i]) {
- auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
- auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
- if (storeIndexDefOp && loadIndexDefOp) {
- if (!isMemoryEffectFree(storeIndexDefOp))
- return WalkResult::interrupt();
- if (!isMemoryEffectFree(loadIndexDefOp))
- return WalkResult::interrupt();
- if (!OperationEquivalence::isEquivalentTo(
- storeIndexDefOp, loadIndexDefOp,
- [&](Value storeIndex, Value loadIndex) {
- if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
- firstToSecondPloopIndices.lookupOrDefault(loadIndex))
- return failure();
- else
- return success();
- },
- /*markEquivalent=*/nullptr,
- OperationEquivalence::Flags::IgnoreLocations)) {
- return WalkResult::interrupt();
- }
- } else
- return WalkResult::interrupt();
- }
- }
- return WalkResult::advance();
- });
- return !walkResult.wasInterrupted();
-}
-
-/// Analyzes dependencies in the most primitive way by checking simple read and
-/// write patterns.
-static LogicalResult
-verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
- const IRMapping &firstToSecondPloopIndices,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
- if (!haveNoReadsAfterWriteExceptSameIndex(
- firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
- return failure();
-
- IRMapping secondToFirstPloopIndices;
- secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
- firstPloop.getBody()->getArguments());
- return success(haveNoReadsAfterWriteExceptSameIndex(
- secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
-}
-
-static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
- const IRMapping &firstToSecondPloopIndices,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
- return !hasNestedParallelOp(firstPloop) &&
- !hasNestedParallelOp(secondPloop) &&
- equalIterationSpaces(firstPloop, secondPloop) &&
- succeeded(verifyDependencies(firstPloop, secondPloop,
- firstToSecondPloopIndices, mayAlias));
-}
-
-/// Prepends operations of firstPloop's body into secondPloop's body.
-/// Updates secondPloop with new loop.
-static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
- OpBuilder builder,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
- Block *block1 = firstPloop.getBody();
- Block *block2 = secondPloop.getBody();
- IRMapping firstToSecondPloopIndices;
- firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
-
- if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
- mayAlias))
- return;
-
- DominanceInfo dom;
- // We are fusing first loop into second, make sure there are no users of the
- // first loop results between loops.
- for (Operation *user : firstPloop->getUsers())
- if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
- return;
-
- ValueRange inits1 = firstPloop.getInitVals();
- ValueRange inits2 = secondPloop.getInitVals();
-
- SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
- newInitVars.append(inits2.begin(), inits2.end());
-
- IRRewriter b(builder);
- b.setInsertionPoint(secondPloop);
- auto newSecondPloop = b.create<ParallelOp>(
- secondPloop.getLoc(), secondPloop.getLowerBound(),
- secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
-
- Block *newBlock = newSecondPloop.getBody();
- auto term1 = cast<ReduceOp>(block1->getTerminator());
- auto term2 = cast<ReduceOp>(block2->getTerminator());
-
- b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
- newBlock->getArguments());
- b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
- newBlock->getArguments());
-
- ValueRange results = newSecondPloop.getResults();
- if (!results.empty()) {
- b.setInsertionPointToEnd(newBlock);
-
- ValueRange reduceArgs1 = term1.getOperands();
- ValueRange reduceArgs2 = term2.getOperands();
- SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
- newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
-
- auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
-
- for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
- term1.getReductions(), term2.getReductions()))) {
- Block &oldRedBlock = reg.front();
- Block &newRedBlock = newReduceOp.getReductions()[i].front();
- b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
- newRedBlock.getArguments());
- }
-
- firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
- secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
- }
- term1->erase();
- term2->erase();
- firstPloop.erase();
- secondPloop.erase();
- secondPloop = newSecondPloop;
-}
-
void mlir::scf::naivelyFuseParallelOps(
Region ®ion, llvm::function_ref<bool(Value, Value)> mayAlias) {
OpBuilder b(region);
@@ -259,7 +59,7 @@ void mlir::scf::naivelyFuseParallelOps(
}
for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
- fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
+ mlir::fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
}
}
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..d85339f32dbe3 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
@@ -1070,6 +1071,206 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
return tileLoops;
}
+/// Checks whether the second parallel loop only reads buffers written by the
+/// first loop at the same (remapped) indices. Returns `false` on any other
+/// read-after-write or possibly-aliasing access, `true` otherwise.
+static bool haveNoReadsAfterWriteExceptSameIndex(
+ scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
+ SmallVector<Value> bufferStoresVec;
+ firstPloop.getBody()->walk([&](memref::StoreOp store) {
+ bufferStores[store.getMemRef()].push_back(store.getIndices());
+ bufferStoresVec.emplace_back(store.getMemRef());
+ });
+ auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
+ Value loadMem = load.getMemRef();
+ // Stop if the memref is defined in secondPloop body. Careful alias analysis
+ // is needed.
+ auto *memrefDef = loadMem.getDefiningOp();
+ if (memrefDef && memrefDef->getBlock() == load->getBlock())
+ return WalkResult::interrupt();
+
+ for (Value store : bufferStoresVec)
+ if (store != loadMem && mayAlias(store, loadMem))
+ return WalkResult::interrupt();
+
+ auto write = bufferStores.find(loadMem);
+ if (write == bufferStores.end())
+ return WalkResult::advance();
+
+ // Check that at least one store was retrieved.
+ if (!write->second.size())
+ return WalkResult::interrupt();
+
+ auto storeIndices = write->second.front();
+
+ // Multiple writes to the same memref are allowed only on the same indices
+ for (const auto &othStoreIndices : write->second) {
+ if (othStoreIndices != storeIndices)
+ return WalkResult::interrupt();
+ }
+
+ // Check that the load indices of secondPloop coincide with store indices of
+ // firstPloop for the same memrefs.
+ auto loadIndices = load.getIndices();
+ if (storeIndices.size() != loadIndices.size())
+ return WalkResult::interrupt();
+ for (int i = 0, e = storeIndices.size(); i < e; ++i) {
+ if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
+ loadIndices[i]) {
+ auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
+ auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
+ if (storeIndexDefOp && loadIndexDefOp) {
+ if (!isMemoryEffectFree(storeIndexDefOp))
+ return WalkResult::interrupt();
+ if (!isMemoryEffectFree(loadIndexDefOp))
+ return WalkResult::interrupt();
+ if (!OperationEquivalence::isEquivalentTo(
+ storeIndexDefOp, loadIndexDefOp,
+ [&](Value storeIndex, Value loadIndex) {
+ if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
+ firstToSecondPloopIndices.lookupOrDefault(loadIndex))
+ return failure();
+ else
+ return success();
+ },
+ /*markEquivalent=*/nullptr,
+ OperationEquivalence::Flags::IgnoreLocations)) {
+ return WalkResult::interrupt();
+ }
+ } else
+ return WalkResult::interrupt();
+ }
+ }
+ return WalkResult::advance();
+ });
+ return !walkResult.wasInterrupted();
+}
+
+/// Analyzes dependencies in the most primitive way by checking simple read and
+/// write patterns.
+static LogicalResult
+verifyDependencies(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ if (!haveNoReadsAfterWriteExceptSameIndex(
+ firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
+ return failure();
+
+ IRMapping secondToFirstPloopIndices;
+ secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
+ firstPloop.getBody()->getArguments());
+ return success(haveNoReadsAfterWriteExceptSameIndex(
+ secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
+}
+
+/// Verify equal iteration spaces.
+static bool equalIterationSpaces(scf::ParallelOp firstPloop,
+ scf::ParallelOp secondPloop) {
+ if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+ return false;
+
+ auto matchOperands = [&](const OperandRange &lhs,
+ const OperandRange &rhs) -> bool {
+ // TODO: Extend this to support aliases and equal constants.
+ return std::equal(lhs.begin(), lhs.end(), rhs.begin());
+ };
+ return matchOperands(firstPloop.getLowerBound(),
+ secondPloop.getLowerBound()) &&
+ matchOperands(firstPloop.getUpperBound(),
+ secondPloop.getUpperBound()) &&
+ matchOperands(firstPloop.getStep(), secondPloop.getStep());
+}
+
+/// Verify there are no nested ParallelOps.
+static bool hasNestedParallelOp(scf::ParallelOp ploop) {
+ auto walkResult = ploop.getBody()->walk(
+ [](scf::ParallelOp) { return WalkResult::interrupt(); });
+ return walkResult.wasInterrupted();
+}
+
+static bool isFusionLegal(scf::ParallelOp firstPloop,
+ scf::ParallelOp secondPloop,
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ return !hasNestedParallelOp(firstPloop) &&
+ !hasNestedParallelOp(secondPloop) &&
+ equalIterationSpaces(firstPloop, secondPloop) &&
+ succeeded(verifyDependencies(firstPloop, secondPloop,
+ firstToSecondPloopIndices, mayAlias));
+}
+
+void mlir::fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
+ OpBuilder builder,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ Block *block1 = firstPloop.getBody();
+ Block *block2 = secondPloop.getBody();
+ IRMapping firstToSecondPloopIndices;
+ firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
+
+ if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
+ mayAlias))
+ return;
+
+ DominanceInfo dom;
+ // We are fusing first loop into second, make sure there are no users of the
+ // first loop results between loops.
+ for (Operation *user : firstPloop->getUsers())
+ if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+ return;
+
+ ValueRange inits1 = firstPloop.getInitVals();
+ ValueRange inits2 = secondPloop.getInitVals();
+
+ SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+ newInitVars.append(inits2.begin(), inits2.end());
+
+ IRRewriter b(builder);
+ b.setInsertionPoint(secondPloop);
+ auto newSecondPloop = b.create<scf::ParallelOp>(
+ secondPloop.getLoc(), secondPloop.getLowerBound(),
+ secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+
+ Block *newBlock = newSecondPloop.getBody();
+ auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
+ auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
+
+ b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
+ newBlock->getArguments());
+ b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
+ newBlock->getArguments());
+
+ ValueRange results = newSecondPloop.getResults();
+ if (!results.empty()) {
+ b.setInsertionPointToEnd(newBlock);
+
+ ValueRange reduceArgs1 = term1.getOperands();
+ ValueRange reduceArgs2 = term2.getOperands();
+ SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+ newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+ auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+
+ for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+ term1.getReductions(), term2.getReductions()))) {
+ Block &oldRedBlock = reg.front();
+ Block &newRedBlock = newReduceOp.getReductions()[i].front();
+ b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
+ newRedBlock.getArguments());
+ }
+
+ firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+ secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
+ }
+ term1->erase();
+ term2->erase();
+ firstPloop.erase();
+ secondPloop.erase();
+ secondPloop = newSecondPloop;
+}
+
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForallOp source,
RewriterBase &rewriter) {
@@ -1171,3 +1372,10 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
return fusedLoop;
}
+
+scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
+ scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
+ auto mayAlias = [&](Value val1, Value val2) -> bool { return false; };
+ mlir::fuseIfLegal(target, source, rewriter, mayAlias);
+ return source;
+}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 0f51b1cdbe0cf..46c6be36c3271 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -47,6 +47,59 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func @fuse_two_parallel
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+func.func @fuse_two_parallel(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c1fp = arith.constant 1.0 : f32
+// CHECK: [[SUM:%.*]] = memref.alloc()
+ %sum = memref.alloc() : memref<2x2xf32>
+// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
+// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
+// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: scf.reduce
+// CHECK: }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %c1fp : f32
+ memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+// CHECK: memref.dealloc [[SUM]]
+ memref.dealloc %sum : memref<2x2xf32>
+ return
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
>From b73238a9472b0682f250e37848ad504d21a57059 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 10:47:00 -0500
Subject: [PATCH 03/13] add checkFusionStructuralLegality
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 7 ++++++
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 26 +++++++++++++++++++++
2 files changed, 33 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 2944d8ffac022..834857f177cdf 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -156,6 +156,13 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
scf::ForOp root);
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
+
+template <typename LoopTy>
+bool checkFusionStructuralLegality(Operation *target, Operation *source);
+
/// Prepends operations of firstPloop's body into secondPloop's body.
/// Updates secondPloop with new loop.
void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index d85339f32dbe3..c490983335470 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1184,6 +1184,10 @@ static bool equalIterationSpaces(scf::ParallelOp firstPloop,
matchOperands(firstPloop.getStep(), secondPloop.getStep());
}
+//===----------------------------------------------------------------------===//
+// Fusion related helpers
+//===----------------------------------------------------------------------===//
+
/// Verify there are no nested ParallelOps.
static bool hasNestedParallelOp(scf::ParallelOp ploop) {
auto walkResult = ploop.getBody()->walk(
@@ -1191,6 +1195,28 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) {
return walkResult.wasInterrupted();
}
+template <typename LoopTy>
+static bool checkFusionStructuralLegality(Operation *target,
+ Operation *source) {
+ static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
+ scf::ParallelOp>::value,
+ "applies to only `forall`, `for` and `parallel`");
+ auto targetOp = dyn_cast<LoopTy>(target);
+ auto sourceOp = dyn_cast<LoopTy>(source);
+ if (!targetOp || !sourceOp)
+ return false;
+
+ if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
+ return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
+ targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
+ targetOp.getMixedStep() == sourceOp.getMixedStep() &&
+ targetOp.getMapping() == sourceOp.getMapping();
+ else
+ return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+ targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+ targetOp.getStep() == sourceOp.getStep();
+}
+
static bool isFusionLegal(scf::ParallelOp firstPloop,
scf::ParallelOp secondPloop,
const IRMapping &firstToSecondPloopIndices,
>From f5bbd131bb7713ae47a58587b2d9acf82dc3b12f Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 14:19:56 -0500
Subject: [PATCH 04/13] replace isLoopWithIdenticalConfiguration with
checkFusionStructuralLegality
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 4 +-
.../SCF/TransformOps/SCFTransformOps.cpp | 50 ++++++-------------
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 32 +++++-------
3 files changed, 29 insertions(+), 57 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 834857f177cdf..ab9d154aa480d 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -160,8 +160,8 @@ void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
// Fusion related helpers
//===----------------------------------------------------------------------===//
-template <typename LoopTy>
-bool checkFusionStructuralLegality(Operation *target, Operation *source);
+bool checkFusionStructuralLegality(LoopLikeOpInterface &target,
+ LoopLikeOpInterface &source);
/// Prepends operations of firstPloop's body into secondPloop's body.
/// Updates secondPloop with new loop.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 1c53e89d69040..9f541b94af474 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -442,34 +442,6 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
return DiagnosedSilenceableFailure::success();
}
-/// Check if `target` scf loop can be fused into `source` scf loop.
-/// Applies for scf.for, scf.forall, and scf.parallel.
-///
-/// This simply checks if both loops have the same bounds, steps and mapping.
-/// No attempt is made at checking that the side effects of `target` and
-/// `source` are independent of each other.
-template <typename LoopTy>
-static bool isLoopWithIdenticalConfiguration(Operation *target,
- Operation *source) {
- static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
- scf::ParallelOp>::value,
- "applies to only `forall`, `for` and `parallel`");
- auto targetOp = dyn_cast<LoopTy>(target);
- auto sourceOp = dyn_cast<LoopTy>(source);
- if (!targetOp || !sourceOp)
- return false;
-
- if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
- return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
- targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
- targetOp.getMixedStep() == sourceOp.getMixedStep() &&
- targetOp.getMapping() == sourceOp.getMapping();
- else
- return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
- targetOp.getUpperBound() == sourceOp.getUpperBound() &&
- targetOp.getStep() == sourceOp.getStep();
-}
-
DiagnosedSilenceableFailure
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
@@ -485,29 +457,37 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
<< "source handle (got " << llvm::range_size(sourceOps) << ")";
}
- Operation *target = *targetOps.begin();
- Operation *source = *sourceOps.begin();
+ auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
+ auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
+ // Either cast may have failed (so `target` may be null); take the
+ // diagnostic location from the payload op instead of dereferencing it.
+ if (!target || !source)
+ return emitSilenceableFailure((*targetOps.begin())->getLoc())
+ << "target or source is not a loop op";
// Check if the target and source are siblings.
DiagnosedSilenceableFailure diag = isOpSibling(target, source);
if (!diag.succeeded())
return diag;
+ if (!mlir::checkFusionStructuralLegality(target, source))
+ return emitSilenceableFailure(target->getLoc())
+ << "operations cannot be fused";
+
Operation *fusedLoop;
/// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
- if (isLoopWithIdenticalConfiguration<scf::ForOp>(target, source)) {
+ if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
fusedLoop = fuseIndependentSiblingForLoops(
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
- } else if (isLoopWithIdenticalConfiguration<scf::ForallOp>(target, source)) {
+ } else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
fusedLoop = fuseIndependentSiblingForallLoops(
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
- } else if (isLoopWithIdenticalConfiguration<scf::ParallelOp>(target,
- source)) {
+ } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
fusedLoop = fuseIndependentSiblingParallelLoops(
cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
} else
return emitSilenceableFailure(target->getLoc())
- << "operations cannot be fused";
+ << "unsupported loop type for fusion";
assert(fusedLoop && "failed to fuse operations");
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index c490983335470..ce20730459c2a 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1195,26 +1195,18 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) {
return walkResult.wasInterrupted();
}
-template <typename LoopTy>
-static bool checkFusionStructuralLegality(Operation *target,
- Operation *source) {
- static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
- scf::ParallelOp>::value,
- "applies to only `forall`, `for` and `parallel`");
- auto targetOp = dyn_cast<LoopTy>(target);
- auto sourceOp = dyn_cast<LoopTy>(source);
- if (!targetOp || !sourceOp)
- return false;
-
- if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
- return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
- targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
- targetOp.getMixedStep() == sourceOp.getMixedStep() &&
- targetOp.getMapping() == sourceOp.getMapping();
- else
- return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
- targetOp.getUpperBound() == sourceOp.getUpperBound() &&
- targetOp.getStep() == sourceOp.getStep();
+bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
+ LoopLikeOpInterface &source) {
+ auto iterSpaceEq =
+ target.getMixedLowerBound() == source.getMixedLowerBound() &&
+ target.getMixedUpperBound() == source.getMixedUpperBound() &&
+ target.getMixedStep() == source.getMixedStep();
+ auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
+ auto forAllSource = dyn_cast<scf::ForallOp>(*source);
+ if (forAllTarget && forAllSource)
+ return iterSpaceEq &&
+ forAllTarget.getMapping() == forAllSource.getMapping();
+ return iterSpaceEq;
}
static bool isFusionLegal(scf::ParallelOp firstPloop,
>From 7d995815064cb25e47ec8e400de3692fbe5fdfba Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 14:37:57 -0500
Subject: [PATCH 05/13] address review comment
---
.../mlir/Interfaces/LoopLikeInterface.td | 20 +++++++++++--------
1 file changed, 12 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 813779c852027..5cf3eba0bd9ed 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -234,29 +234,33 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/// If there is a single induction variable return it, otherwise return
/// std::nullopt.
::std::optional<::mlir::Value> getSingleInductionVar() {
- if (this->getInductionVars().size() == 1)
- return this->getInductionVars()[0];
+ auto inductionVars = this->getInductionVars();
+ if (inductionVars.size() == 1)
+ return inductionVars[0];
return std::nullopt;
}
/// Return the single lower bound value or attribute if it exists, otherwise
/// return std::nullopt.
::std::optional<::mlir::OpFoldResult> getSingleLowerBound() {
- if (this->getMixedLowerBound().size() == 1)
- return this->getMixedLowerBound()[0];
+ auto lowerBounds = this->getMixedLowerBound();
+ if (lowerBounds.size() == 1)
+ return lowerBounds[0];
return std::nullopt;
}
/// Return the single step value or attribute if it exists, otherwise
/// return std::nullopt.
::std::optional<::mlir::OpFoldResult> getSingleStep() {
- if (this->getMixedStep().size() == 1)
- return this->getMixedStep()[0];
+ auto steps = this->getMixedStep();
+ if (steps.size() == 1)
+ return steps[0];
return std::nullopt;
}
/// Return the single upper bound value or attribute if it exists, otherwise
/// return std::nullopt.
::std::optional<::mlir::OpFoldResult> getSingleUpperBound() {
- if (this->getMixedUpperBound().size() == 1)
- return this->getMixedUpperBound()[0];
+ auto upperBounds = this->getMixedUpperBound();
+ if (upperBounds.size() == 1)
+ return upperBounds[0];
return std::nullopt;
}
>From a5fa3b3c4903c344847ee544cd9812b6f0c70571 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 5 Jun 2024 20:56:02 -0500
Subject: [PATCH 06/13] Make return types optional and change names
---
.../mlir/Dialect/Affine/IR/AffineOps.td | 10 +--
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 29 +++++++--
.../mlir/Interfaces/LoopLikeInterface.td | 42 ++++++-------
.../AffineToStandard/AffineToStandard.cpp | 4 +-
.../SCFToControlFlow/SCFToControlFlow.cpp | 9 +--
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 27 ++++----
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 2 +-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 26 ++++----
.../Dialect/SCF/Transforms/ForallToFor.cpp | 9 +--
.../Dialect/SCF/LoopLikeSCFOpsTest.cpp | 62 +++++++++++++------
10 files changed, 131 insertions(+), 89 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index bb2c29b5733b8..4c032e66f7a83 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -118,8 +118,8 @@ def AffineForOp : Affine_Op<"for",
[AttrSizedOperandSegments, AutomaticAllocationScope,
ImplicitAffineTerminator, ConditionallySpeculatable,
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
- ["getInductionVars", "getMixedLowerBound", "getMixedStep",
- "getMixedUpperBound", "getYieldedValuesMutable",
+ ["getInductionVars", "getLowerBounds", "getSteps",
+ "getUpperBounds", "getYieldedValuesMutable",
"replaceWithAdditionalYields"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>]> {
@@ -671,7 +671,7 @@ def AffineParallelOp : Affine_Op<"parallel",
I32ElementsAttr:$lowerBoundsGroups,
AffineMapAttr:$upperBoundsMap,
I32ElementsAttr:$upperBoundsGroups,
- I64SmallVectorArrayAttr:$steps,
+ I64SmallVectorArrayAttr:$step,
Variadic<Index>:$mapOperands);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@@ -682,7 +682,7 @@ def AffineParallelOp : Affine_Op<"parallel",
OpBuilder<(ins "TypeRange":$resultTypes,
"ArrayRef<arith::AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
"ValueRange":$lbArgs, "ArrayRef<AffineMap>":$ubMaps, "ValueRange":$ubArgs,
- "ArrayRef<int64_t>":$steps)>
+ "ArrayRef<int64_t>":$step)>
];
let extraClassDeclaration = [{
@@ -727,7 +727,7 @@ def AffineParallelOp : Affine_Op<"parallel",
static StringRef getUpperBoundsGroupsAttrStrName() {
return "upperBoundsGroups";
}
- static StringRef getStepsAttrStrName() { return "steps"; }
+ static StringRef getStepsAttrStrName() { return "step"; }
/// Returns `true` if the loop bounds have min/max expressions.
bool hasMinMaxBounds() {
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 3b28ca8b21d0f..66b478f141b32 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
- "getInductionVars", "getMixedLowerBound", "getMixedStep",
- "getMixedUpperBound", "getYieldedValuesMutable",
+ "getInductionVars", "getLowerBounds", "getSteps",
+ "getUpperBounds", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
@@ -302,7 +302,7 @@ def ForallOp : SCF_Op<"forall", [
AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getRegionIterArgs", "getInductionVars",
- "getMixedLowerBound", "getMixedUpperBound", "getMixedStep",
+ "getLowerBounds", "getUpperBounds", "getSteps",
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::InParallelOp">,
@@ -510,6 +510,27 @@ def ForallOp : SCF_Op<"forall", [
];
let extraClassDeclaration = [{
+ // Get lower bounds as OpFoldResult.
+ SmallVector<OpFoldResult> getMixedLowerBound() {
+ auto maybeLowerBounds = getLowerBounds();
+ assert(maybeLowerBounds.has_value() && "expected values");
+ return *maybeLowerBounds;
+ }
+
+ // Get upper bounds as OpFoldResult.
+ SmallVector<OpFoldResult> getMixedUpperBound() {
+ auto maybeUpperBounds = getUpperBounds();
+ assert(maybeUpperBounds.has_value() && "expected values");
+ return *maybeUpperBounds;
+ }
+
+ // Get steps as OpFoldResult.
+ SmallVector<OpFoldResult> getMixedStep() {
+ auto maybeSteps = getSteps();
+ assert(maybeSteps.has_value() && "expected values");
+ return *maybeSteps;
+ }
+
/// Get lower bounds as values.
SmallVector<Value> getLowerBound(OpBuilder &b) {
return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedLowerBound());
@@ -744,7 +765,7 @@ def ParallelOp : SCF_Op<"parallel",
[AutomaticAllocationScope,
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getInductionVars",
- "getMixedLowerBound", "getMixedUpperBound", "getMixedStep"]>,
+ "getLowerBounds", "getUpperBounds", "getSteps"]>,
RecursiveMemoryEffects,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"scf::ReduceOp">,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 5cf3eba0bd9ed..cc79d026c8d4e 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -106,34 +106,34 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
InterfaceMethod<[{
Return all lower bounds.
}],
- /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
- /*methodName=*/"getMixedLowerBound",
+ /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
+ /*methodName=*/"getLowerBounds",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return {};
+ return std::nullopt;
}]
>,
InterfaceMethod<[{
Return all steps.
}],
- /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
- /*methodName=*/"getMixedStep",
+ /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
+ /*methodName=*/"getSteps",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return {};
+ return std::nullopt;
}]
>,
InterfaceMethod<[{
Return all upper bounds.
}],
- /*retTy=*/"::llvm::SmallVector<::mlir::OpFoldResult>",
- /*methodName=*/"getMixedUpperBound",
+ /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
+ /*methodName=*/"getUpperBounds",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return {};
+ return std::nullopt;
}]
>,
InterfaceMethod<[{
@@ -242,26 +242,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/// Return the single lower bound value or attribute if it exists, otherwise
/// return std::nullopt.
::std::optional<::mlir::OpFoldResult> getSingleLowerBound() {
- auto lowerBounds = this->getMixedLowerBound();
- if (lowerBounds.size() == 1)
- return lowerBounds[0];
- return std::nullopt;
+ auto lowerBounds = this->getLowerBounds();
+ if (lowerBounds.has_value() && (*lowerBounds).size() == 1)
+ return (*lowerBounds)[0];
+ return std::nullopt;
}
/// Return the single step value or attribute if it exists, otherwise
/// return std::nullopt.
::std::optional<::mlir::OpFoldResult> getSingleStep() {
- auto steps = this->getMixedStep();
- if (steps.size() == 1)
- return steps[0];
- return std::nullopt;
+ auto steps = this->getSteps();
+ if (steps.has_value() && (*steps).size() == 1)
+ return (*steps)[0];
+ return std::nullopt;
}
/// Return the single upper bound value or attribute if it exists, otherwise
/// return std::nullopt.
::std::optional<::mlir::OpFoldResult> getSingleUpperBound() {
- auto upperBounds = this->getMixedUpperBound();
- if (upperBounds.size() == 1)
- return upperBounds[0];
- return std::nullopt;
+ auto upperBounds = this->getUpperBounds();
+ if (upperBounds.has_value() && (*upperBounds).size() == 1)
+ return (*upperBounds)[0];
+ return std::nullopt;
}
/// Append the specified additional "init" operands: replace this loop with
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 10ccd5c97783b..20487b32e3fe0 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -196,8 +196,8 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
upperBoundTuple.push_back(upper);
}
- steps.reserve(op.getSteps().size());
- for (int64_t step : op.getSteps())
+ steps.reserve(op.getStep().size());
+ for (int64_t step : op.getStep())
steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
// Get the terminator op.
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 9eb8a289d7d65..48e1d88c1c75e 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -695,12 +695,9 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
"only fully bufferized scf.forall ops can be lowered to scf.parallel");
// Convert mixed bounds and steps to SSA values.
- SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
- rewriter, loc, forallOp.getMixedLowerBound());
- SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
- rewriter, loc, forallOp.getMixedUpperBound());
- SmallVector<Value> steps =
- getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+ SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
+ SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
+ SmallVector<Value> steps = forallOp.getStep(rewriter);
// Create empty scf.parallel op.
auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 746a9c919560c..d3f034a0660ba 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2456,23 +2456,26 @@ SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
ValueRange AffineForOp::getInductionVars() { return {getInductionVar()}; }
-SmallVector<OpFoldResult> AffineForOp::getMixedLowerBound() {
+std::optional<SmallVector<OpFoldResult>> AffineForOp::getLowerBounds() {
if (!hasConstantLowerBound())
- return {};
+ return std::nullopt;
OpBuilder b(getContext());
- return {OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
+ return SmallVector<OpFoldResult>{
+ OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
}
-SmallVector<OpFoldResult> AffineForOp::getMixedStep() {
+std::optional<SmallVector<OpFoldResult>> AffineForOp::getSteps() {
OpBuilder b(getContext());
- return {OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
+ return SmallVector<OpFoldResult>{
+ OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
}
-SmallVector<OpFoldResult> AffineForOp::getMixedUpperBound() {
+std::optional<SmallVector<OpFoldResult>> AffineForOp::getUpperBounds() {
if (!hasConstantUpperBound())
return {};
OpBuilder b(getContext());
- return {OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
+ return SmallVector<OpFoldResult>{
+ OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
}
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
@@ -3753,7 +3756,7 @@ SmallVector<Region *> AffineParallelOp::getLoopRegions() {
return {&getRegion()};
}
-unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
+unsigned AffineParallelOp::getNumDims() { return getStep().size(); }
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
return getOperands().take_front(getLowerBoundsMap().getNumInputs());
@@ -3838,7 +3841,7 @@ void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
}
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
- setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
+ setStepAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
// check whether resultType match op or not in affine.parallel
@@ -3888,14 +3891,14 @@ LogicalResult AffineParallelOp::verify() {
auto numDims = getNumDims();
if (getLowerBoundsGroups().getNumElements() != numDims ||
getUpperBoundsGroups().getNumElements() != numDims ||
- getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
+ getStep().size() != numDims || getBody()->getNumArguments() != numDims) {
return emitOpError() << "the number of region arguments ("
<< getBody()->getNumArguments()
<< ") and the number of map groups for lower ("
<< getLowerBoundsGroups().getNumElements()
<< ") and upper bound ("
<< getUpperBoundsGroups().getNumElements()
- << "), and the number of steps (" << getSteps().size()
+ << "), and the number of steps (" << getStep().size()
<< ") must all match";
}
@@ -4013,7 +4016,7 @@ void AffineParallelOp::print(OpAsmPrinter &p) {
printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
getUpperBoundsOperands(), "min");
p << ')';
- SmallVector<int64_t, 8> steps = getSteps();
+ SmallVector<int64_t, 8> steps = getStep();
bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
if (!elideSteps) {
p << " step (";
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index f46381403bc52..a652ee4a488d1 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -494,7 +494,7 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
return;
AffineMap lbMap = op.getLowerBoundsMap();
- SmallVector<int64_t, 8> steps = op.getSteps();
+ SmallVector<int64_t, 8> steps = op.getStep();
// No need to do any work if the parallel op is already normalized.
bool isAlreadyNormalized =
llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index e275ff1849c10..281d73afee4a8 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -380,16 +380,16 @@ LogicalResult ForOp::verifyRegions() {
ValueRange ForOp::getInductionVars() { return {getInductionVar()}; }
-SmallVector<OpFoldResult> ForOp::getMixedLowerBound() {
- return {OpFoldResult(getLowerBound())};
+std::optional<SmallVector<OpFoldResult>> ForOp::getLowerBounds() {
+ return SmallVector<OpFoldResult, 1>{OpFoldResult(getLowerBound())};
}
-SmallVector<OpFoldResult> ForOp::getMixedStep() {
- return {OpFoldResult(getStep())};
+std::optional<SmallVector<OpFoldResult>> ForOp::getSteps() {
+ return SmallVector<OpFoldResult, 1>{OpFoldResult(getStep())};
}
-SmallVector<OpFoldResult> ForOp::getMixedUpperBound() {
- return {OpFoldResult(getUpperBound())};
+std::optional<SmallVector<OpFoldResult>> ForOp::getUpperBounds() {
+ return SmallVector<OpFoldResult, 1>{OpFoldResult(getUpperBound())};
}
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
@@ -1431,19 +1431,19 @@ ValueRange ForallOp::getInductionVars() {
}
// Get lower bounds as OpFoldResult.
-SmallVector<OpFoldResult> ForallOp::getMixedLowerBound() {
+std::optional<SmallVector<OpFoldResult>> ForallOp::getLowerBounds() {
Builder b(getOperation()->getContext());
return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
}
// Get upper bounds as OpFoldResult.
-SmallVector<OpFoldResult> ForallOp::getMixedUpperBound() {
+std::optional<SmallVector<OpFoldResult>> ForallOp::getUpperBounds() {
Builder b(getOperation()->getContext());
return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
}
// Get steps as OpFoldResult.
-SmallVector<OpFoldResult> ForallOp::getMixedStep() {
+std::optional<SmallVector<OpFoldResult>> ForallOp::getSteps() {
Builder b(getOperation()->getContext());
return getMixedValues(getStaticStep(), getDynamicStep(), b);
}
@@ -3006,15 +3006,17 @@ SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
ValueRange ParallelOp::getInductionVars() { return getBody()->getArguments(); }
-SmallVector<OpFoldResult> ParallelOp::getMixedLowerBound() {
+std::optional<SmallVector<OpFoldResult>> ParallelOp::getLowerBounds() {
return getLowerBound();
}
-SmallVector<OpFoldResult> ParallelOp::getMixedUpperBound() {
+std::optional<SmallVector<OpFoldResult>> ParallelOp::getUpperBounds() {
return getUpperBound();
}
-SmallVector<OpFoldResult> ParallelOp::getMixedStep() { return getStep(); }
+std::optional<SmallVector<OpFoldResult>> ParallelOp::getSteps() {
+ return getStep();
+}
ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index 198cb2e6cc69e..5da1b76e929be 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -34,12 +34,9 @@ mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
rewriter.setInsertionPoint(forallOp);
Location loc = forallOp.getLoc();
- SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
- rewriter, loc, forallOp.getMixedLowerBound());
- SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
- rewriter, loc, forallOp.getMixedUpperBound());
- SmallVector<Value> steps =
- getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+ SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
+ SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
+ SmallVector<Value> steps = forallOp.getStep(rewriter);
LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
SmallVector<Value> ivs = llvm::map_to_vector(
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index d8cdb213070da..07504a99fecd3 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -27,35 +27,57 @@ class SCFLoopLikeTest : public ::testing::Test {
}
void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
- std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
+ std::optional<OpFoldResult> maybeSingleLb =
+ loopLikeOp.getSingleLowerBound();
+ EXPECT_TRUE(maybeSingleLb.has_value());
+ std::optional<OpFoldResult> maybeSingleUb =
+ loopLikeOp.getSingleUpperBound();
+ EXPECT_TRUE(maybeSingleUb.has_value());
+ std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep();
+ EXPECT_TRUE(maybeSingleStep.has_value());
+ std::optional<OpFoldResult> maybeSingleIndVar =
+ loopLikeOp.getSingleInductionVar();
+ EXPECT_TRUE(maybeSingleIndVar.has_value());
+
+ std::optional<SmallVector<OpFoldResult>> maybeLb =
+ loopLikeOp.getLowerBounds();
EXPECT_TRUE(maybeLb.has_value());
- std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
+ EXPECT_EQ((*maybeLb).size(), 1u);
+ std::optional<SmallVector<OpFoldResult>> maybeUb =
+ loopLikeOp.getUpperBounds();
EXPECT_TRUE(maybeUb.has_value());
- std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
+ EXPECT_EQ((*maybeUb).size(), 1u);
+ std::optional<SmallVector<OpFoldResult>> maybeStep = loopLikeOp.getSteps();
EXPECT_TRUE(maybeStep.has_value());
- std::optional<OpFoldResult> maybeIndVar =
- loopLikeOp.getSingleInductionVar();
- EXPECT_TRUE(maybeIndVar.has_value());
+ EXPECT_EQ((*maybeStep).size(), 1u);
EXPECT_EQ(loopLikeOp.getInductionVars().size(), 1u);
- EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u);
- EXPECT_EQ(loopLikeOp.getMixedStep().size(), 1u);
- EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 1u);
}
void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
- std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
- EXPECT_FALSE(maybeLb.has_value());
- std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
- EXPECT_FALSE(maybeUb.has_value());
- std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
- EXPECT_FALSE(maybeStep.has_value());
- std::optional<OpFoldResult> maybeIndVar =
+ std::optional<OpFoldResult> maybeSingleLb =
+ loopLikeOp.getSingleLowerBound();
+ EXPECT_FALSE(maybeSingleLb.has_value());
+ std::optional<OpFoldResult> maybeSingleUb =
+ loopLikeOp.getSingleUpperBound();
+ EXPECT_FALSE(maybeSingleUb.has_value());
+ std::optional<OpFoldResult> maybeSingleStep = loopLikeOp.getSingleStep();
+ EXPECT_FALSE(maybeSingleStep.has_value());
+ std::optional<OpFoldResult> maybeSingleIndVar =
loopLikeOp.getSingleInductionVar();
- EXPECT_FALSE(maybeIndVar.has_value());
+ EXPECT_FALSE(maybeSingleIndVar.has_value());
+
+ std::optional<SmallVector<OpFoldResult>> maybeLb =
+ loopLikeOp.getLowerBounds();
+ EXPECT_TRUE(maybeLb.has_value());
+ EXPECT_EQ((*maybeLb).size(), 2u);
+ std::optional<SmallVector<OpFoldResult>> maybeUb =
+ loopLikeOp.getUpperBounds();
+ EXPECT_TRUE(maybeUb.has_value());
+ EXPECT_EQ((*maybeUb).size(), 2u);
+ std::optional<SmallVector<OpFoldResult>> maybeStep = loopLikeOp.getSteps();
+ EXPECT_TRUE(maybeStep.has_value());
+ EXPECT_EQ((*maybeStep).size(), 2u);
EXPECT_EQ(loopLikeOp.getInductionVars().size(), 2u);
- EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u);
- EXPECT_EQ(loopLikeOp.getMixedStep().size(), 2u);
- EXPECT_EQ(loopLikeOp.getMixedLowerBound().size(), 2u);
}
MLIRContext context;
>From 1babe681d7858a4992303c62e22684cb73d82472 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 6 Jun 2024 11:31:11 -0500
Subject: [PATCH 07/13] Change return type of getInductionVars to
SmallVector<Value>
---
mlir/include/mlir/Interfaces/LoopLikeInterface.td | 2 +-
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 4 +++-
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 3 +--
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 6 +++---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 10 ++++++----
5 files changed, 14 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index cc79d026c8d4e..bace8f8384d44 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -95,7 +95,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
InterfaceMethod<[{
Return all induction variables.
}],
- /*retTy=*/"::mlir::ValueRange",
+ /*retTy=*/"::llvm::SmallVector<::mlir::Value>",
/*methodName=*/"getInductionVars",
/*args=*/(ins),
/*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d3f034a0660ba..5467c60242664 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2454,7 +2454,9 @@ bool AffineForOp::matchingBoundOperandList() {
SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
-ValueRange AffineForOp::getInductionVars() { return {getInductionVar()}; }
+SmallVector<Value> AffineForOp::getInductionVars() {
+ return {getInductionVar()};
+}
std::optional<SmallVector<OpFoldResult>> AffineForOp::getLowerBounds() {
if (!hasConstantLowerBound())
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b0a4de2da1e86..8b0e04fb61b1b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -184,8 +184,7 @@ static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
for (Operation *loopOp : loopOps) {
llvm::TypeSwitch<Operation *>(loopOp)
.Case([&](scf::ParallelOp parallelOp) {
- allIvs.append(parallelOp.getInductionVars().begin(),
- parallelOp.getInductionVars().end());
+ allIvs.append(parallelOp.getInductionVars());
})
.Case([&](scf::ForOp forOp) {
allIvs.push_back(forOp.getInductionVar());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index fd314ef9f8134..4eacaa8d1e327 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -243,7 +243,7 @@ static void calculateTileOffsetsAndSizes(
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(forallOp.getBody(0));
- ValueRange threadIds = forallOp.getInductionVars();
+ auto threadIds = forallOp.getInductionVars();
SmallVector<OpFoldResult> nonZeroNumThreads =
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
return !isConstantIntValue(ofr, 0);
@@ -746,7 +746,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
b.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = tiledSizes;
sizes[reductionDim] = b.getIndexAttr(1);
- outOffsets[reductionDim] = forallOp.getInductionVars().front();
+ outOffsets[reductionDim] = forallOp.getInductionVars()[0];
// TODO: use SubsetExtractOpInterface once it is available.
tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
loc, cast<RankedTensorType>(initOperand.getType()),
@@ -814,7 +814,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
int64_t sizeIdx = 0;
for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
if (i == reductionDim) {
- resultOffsetsRank.push_back(forallOp.getInductionVars().front());
+ resultOffsetsRank.push_back(forallOp.getInductionVars()[0]);
resultSizesRank.push_back(b.getIndexAttr(1));
continue;
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 281d73afee4a8..0ce10ebdad3e2 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -378,7 +378,7 @@ LogicalResult ForOp::verifyRegions() {
return success();
}
-ValueRange ForOp::getInductionVars() { return {getInductionVar()}; }
+SmallVector<Value> ForOp::getInductionVars() { return {getInductionVar()}; }
std::optional<SmallVector<OpFoldResult>> ForOp::getLowerBounds() {
return SmallVector<OpFoldResult, 1>{OpFoldResult(getLowerBound())};
@@ -1426,8 +1426,8 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
return storeOps;
}
-ValueRange ForallOp::getInductionVars() {
- return getBody()->getArguments().take_front(getRank());
+SmallVector<Value> ForallOp::getInductionVars() {
+ return SmallVector<Value>(getBody()->getArguments().take_front(getRank()));
}
// Get lower bounds as OpFoldResult.
@@ -3004,7 +3004,9 @@ void ParallelOp::print(OpAsmPrinter &p) {
SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
-ValueRange ParallelOp::getInductionVars() { return getBody()->getArguments(); }
+SmallVector<Value> ParallelOp::getInductionVars() {
+ return SmallVector<Value>(getBody()->getArguments());
+}
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLowerBounds() {
return getLowerBound();
>From 009fd15ab8abefd56afe6424e27f912a4166329d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 6 Jun 2024 14:02:52 -0500
Subject: [PATCH 08/13] Address Maks's review comments
---
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 4eacaa8d1e327..a0a0e11a6903d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -243,7 +243,7 @@ static void calculateTileOffsetsAndSizes(
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(forallOp.getBody(0));
- auto threadIds = forallOp.getInductionVars();
+ SmallVector<Value> threadIds = forallOp.getInductionVars();
SmallVector<OpFoldResult> nonZeroNumThreads =
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
return !isConstantIntValue(ofr, 0);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0ce10ebdad3e2..a930f8c71454c 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1427,7 +1427,7 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
}
SmallVector<Value> ForallOp::getInductionVars() {
- return SmallVector<Value>(getBody()->getArguments().take_front(getRank()));
+ return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
}
// Get lower bounds as OpFoldResult.
@@ -3005,7 +3005,7 @@ void ParallelOp::print(OpAsmPrinter &p) {
SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
SmallVector<Value> ParallelOp::getInductionVars() {
- return SmallVector<Value>(getBody()->getArguments());
+ return SmallVector<Value>{getBody()->getArguments()};
}
std::optional<SmallVector<OpFoldResult>> ParallelOp::getLowerBounds() {
>From d34ad95aba669b5700976f0d2ed4d68b4902e9be Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 6 Jun 2024 15:26:06 -0500
Subject: [PATCH 09/13] Change interface method names again and revert steps
operand change
---
.../mlir/Dialect/Affine/IR/AffineOps.td | 10 +++----
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 27 ++++++++++++-------
.../mlir/Interfaces/LoopLikeInterface.td | 16 +++++------
.../AffineToStandard/AffineToStandard.cpp | 4 +--
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 18 ++++++-------
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 2 +-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 24 ++++++++---------
.../Dialect/SCF/LoopLikeSCFOpsTest.cpp | 18 +++++++------
8 files changed, 64 insertions(+), 55 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 4c032e66f7a83..dbec741cf1b1f 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -118,8 +118,8 @@ def AffineForOp : Affine_Op<"for",
[AttrSizedOperandSegments, AutomaticAllocationScope,
ImplicitAffineTerminator, ConditionallySpeculatable,
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
- ["getInductionVars", "getLowerBounds", "getSteps",
- "getUpperBounds", "getYieldedValuesMutable",
+ ["getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
+ "getLoopUpperBounds", "getYieldedValuesMutable",
"replaceWithAdditionalYields"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>]> {
@@ -671,7 +671,7 @@ def AffineParallelOp : Affine_Op<"parallel",
I32ElementsAttr:$lowerBoundsGroups,
AffineMapAttr:$upperBoundsMap,
I32ElementsAttr:$upperBoundsGroups,
- I64SmallVectorArrayAttr:$step,
+ I64SmallVectorArrayAttr:$steps,
Variadic<Index>:$mapOperands);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@@ -682,7 +682,7 @@ def AffineParallelOp : Affine_Op<"parallel",
OpBuilder<(ins "TypeRange":$resultTypes,
"ArrayRef<arith::AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
"ValueRange":$lbArgs, "ArrayRef<AffineMap>":$ubMaps, "ValueRange":$ubArgs,
- "ArrayRef<int64_t>":$step)>
+ "ArrayRef<int64_t>":$steps)>
];
let extraClassDeclaration = [{
@@ -727,7 +727,7 @@ def AffineParallelOp : Affine_Op<"parallel",
static StringRef getUpperBoundsGroupsAttrStrName() {
return "upperBoundsGroups";
}
- static StringRef getStepsAttrStrName() { return "step"; }
+ static StringRef getStepsAttrStrName() { return "steps"; }
/// Returns `true` if the loop bounds have min/max expressions.
bool hasMinMaxBounds() {
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 66b478f141b32..3704b15972278 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -136,8 +136,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
- "getInductionVars", "getLowerBounds", "getSteps",
- "getUpperBounds", "getYieldedValuesMutable",
+ "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
+ "getLoopUpperBounds", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
@@ -301,8 +301,8 @@ def ForallOp : SCF_Op<"forall", [
AttrSizedOperandSegments,
AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
- ["getInitsMutable", "getRegionIterArgs", "getInductionVars",
- "getLowerBounds", "getUpperBounds", "getSteps",
+ ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
+ "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::InParallelOp">,
@@ -510,23 +510,26 @@ def ForallOp : SCF_Op<"forall", [
];
let extraClassDeclaration = [{
+ SmallVector<Value> getInductionVars() {
+ return getLoopInductionVars();
+ }
// Get lower bounds as OpFoldResult.
SmallVector<OpFoldResult> getMixedLowerBound() {
- auto maybeLowerBounds = getLowerBounds();
+ auto maybeLowerBounds = getLoopLowerBounds();
assert(maybeLowerBounds.has_value() && "expected values");
return *maybeLowerBounds;
}
// Get upper bounds as OpFoldResult.
SmallVector<OpFoldResult> getMixedUpperBound() {
- auto maybeUpperBounds = getUpperBounds();
+ auto maybeUpperBounds = getLoopUpperBounds();
assert(maybeUpperBounds.has_value() && "expected values");
return *maybeUpperBounds;
}
// Get steps as OpFoldResult.
SmallVector<OpFoldResult> getMixedStep() {
- auto maybeSteps = getSteps();
+ auto maybeSteps = getLoopSteps();
assert(maybeSteps.has_value() && "expected values");
return *maybeSteps;
}
@@ -588,7 +591,7 @@ def ForallOp : SCF_Op<"forall", [
}
::mlir::Value getInductionVar(int64_t idx) {
- return getInductionVars()[idx];
+ return getLoopInductionVars()[idx];
}
::mlir::Block::BlockArgListType getRegionOutArgs() {
@@ -764,8 +767,8 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
def ParallelOp : SCF_Op<"parallel",
[AutomaticAllocationScope,
AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getInductionVars",
- "getLowerBounds", "getUpperBounds", "getSteps"]>,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getLoopInductionVars",
+ "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps"]>,
RecursiveMemoryEffects,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"scf::ReduceOp">,
@@ -845,6 +848,10 @@ def ParallelOp : SCF_Op<"parallel",
];
let extraClassDeclaration = [{
+ // Get induction variables.
+ SmallVector<Value> getInductionVars() {
+ return getLoopInductionVars();
+ }
unsigned getNumLoops() { return getStep().size(); }
unsigned getNumReductions() { return getInitVals().size(); }
}];
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index bace8f8384d44..5312ace4db68e 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -96,7 +96,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
Return all induction variables.
}],
/*retTy=*/"::llvm::SmallVector<::mlir::Value>",
- /*methodName=*/"getInductionVars",
+ /*methodName=*/"getLoopInductionVars",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -107,7 +107,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
Return all lower bounds.
}],
/*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
- /*methodName=*/"getLowerBounds",
+ /*methodName=*/"getLoopLowerBounds",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -118,7 +118,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
Return all steps.
}],
/*retTy=*/"std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
- /*methodName=*/"getSteps",
+ /*methodName=*/"getLoopSteps",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -129,7 +129,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
Return all upper bounds.
}],
/*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
- /*methodName=*/"getUpperBounds",
+ /*methodName=*/"getLoopUpperBounds",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -234,7 +234,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/// If there is a single induction variable return it, otherwise return
/// std::nullopt.
::std::optional<::mlir::Value> getSingleInductionVar() {
- auto inductionVars = this->getInductionVars();
+ auto inductionVars = this->getLoopInductionVars();
if (inductionVars.size() == 1)
return inductionVars[0];
return std::nullopt;
@@ -242,7 +242,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/// Return the single lower bound value or attribute if it exists, otherwise
/// return std::nullopt.
::std::optional<::mlir::OpFoldResult> getSingleLowerBound() {
- auto lowerBounds = this->getLowerBounds();
+ auto lowerBounds = this->getLoopLowerBounds();
if (lowerBounds.has_value() && (*lowerBounds).size() == 1)
return (*lowerBounds)[0];
return std::nullopt;
@@ -250,7 +250,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/// Return the single step value or attribute if it exists, otherwise
/// return std::nullopt.
::std::optional<::mlir::OpFoldResult> getSingleStep() {
- auto steps = this->getSteps();
+ auto steps = this->getLoopSteps();
if (steps.has_value() && (*steps).size() == 1)
return (*steps)[0];
return std::nullopt;
@@ -258,7 +258,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/// Return the single upper bound value or attribute if it exists, otherwise
/// return std::nullopt.
::std::optional<::mlir::OpFoldResult> getSingleUpperBound() {
- auto upperBounds = this->getUpperBounds();
+ auto upperBounds = this->getLoopUpperBounds();
if (upperBounds.has_value() && (*upperBounds).size() == 1)
return (*upperBounds)[0];
return std::nullopt;
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 20487b32e3fe0..10ccd5c97783b 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -196,8 +196,8 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
upperBoundTuple.push_back(upper);
}
- steps.reserve(op.getStep().size());
- for (int64_t step : op.getStep())
+ steps.reserve(op.getSteps().size());
+ for (int64_t step : op.getSteps())
steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
// Get the terminator op.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 5467c60242664..d5cb04743dfb9 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2454,11 +2454,11 @@ bool AffineForOp::matchingBoundOperandList() {
SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
-SmallVector<Value> AffineForOp::getInductionVars() {
+SmallVector<Value> AffineForOp::getLoopInductionVars() {
return {getInductionVar()};
}
-std::optional<SmallVector<OpFoldResult>> AffineForOp::getLowerBounds() {
+std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
if (!hasConstantLowerBound())
return std::nullopt;
OpBuilder b(getContext());
@@ -2466,13 +2466,13 @@ std::optional<SmallVector<OpFoldResult>> AffineForOp::getLowerBounds() {
OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
}
-std::optional<SmallVector<OpFoldResult>> AffineForOp::getSteps() {
+std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
OpBuilder b(getContext());
return SmallVector<OpFoldResult>{
OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
}
-std::optional<SmallVector<OpFoldResult>> AffineForOp::getUpperBounds() {
+std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
if (!hasConstantUpperBound())
return {};
OpBuilder b(getContext());
@@ -3758,7 +3758,7 @@ SmallVector<Region *> AffineParallelOp::getLoopRegions() {
return {&getRegion()};
}
-unsigned AffineParallelOp::getNumDims() { return getStep().size(); }
+unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
return getOperands().take_front(getLowerBoundsMap().getNumInputs());
@@ -3843,7 +3843,7 @@ void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
}
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
- setStepAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
+ setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
// check whether resultType match op or not in affine.parallel
@@ -3893,14 +3893,14 @@ LogicalResult AffineParallelOp::verify() {
auto numDims = getNumDims();
if (getLowerBoundsGroups().getNumElements() != numDims ||
getUpperBoundsGroups().getNumElements() != numDims ||
- getStep().size() != numDims || getBody()->getNumArguments() != numDims) {
+ getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
return emitOpError() << "the number of region arguments ("
<< getBody()->getNumArguments()
<< ") and the number of map groups for lower ("
<< getLowerBoundsGroups().getNumElements()
<< ") and upper bound ("
<< getUpperBoundsGroups().getNumElements()
- << "), and the number of steps (" << getStep().size()
+ << "), and the number of steps (" << getSteps().size()
<< ") must all match";
}
@@ -4018,7 +4018,7 @@ void AffineParallelOp::print(OpAsmPrinter &p) {
printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
getUpperBoundsOperands(), "min");
p << ')';
- SmallVector<int64_t, 8> steps = getStep();
+ SmallVector<int64_t, 8> steps = getSteps();
bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
if (!elideSteps) {
p << " step (";
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index a652ee4a488d1..f46381403bc52 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -494,7 +494,7 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
return;
AffineMap lbMap = op.getLowerBoundsMap();
- SmallVector<int64_t, 8> steps = op.getStep();
+ SmallVector<int64_t, 8> steps = op.getSteps();
// No need to do any work if the parallel op is already normalized.
bool isAlreadyNormalized =
llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index a930f8c71454c..e921177f73215 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -378,17 +378,17 @@ LogicalResult ForOp::verifyRegions() {
return success();
}
-SmallVector<Value> ForOp::getInductionVars() { return {getInductionVar()}; }
+SmallVector<Value> ForOp::getLoopInductionVars() { return {getInductionVar()}; }
-std::optional<SmallVector<OpFoldResult>> ForOp::getLowerBounds() {
+std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
return SmallVector<OpFoldResult, 1>{OpFoldResult(getLowerBound())};
}
-std::optional<SmallVector<OpFoldResult>> ForOp::getSteps() {
+std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
return SmallVector<OpFoldResult, 1>{OpFoldResult(getStep())};
}
-std::optional<SmallVector<OpFoldResult>> ForOp::getUpperBounds() {
+std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
return SmallVector<OpFoldResult, 1>{OpFoldResult(getUpperBound())};
}
@@ -1426,24 +1426,24 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
return storeOps;
}
-SmallVector<Value> ForallOp::getInductionVars() {
+SmallVector<Value> ForallOp::getLoopInductionVars() {
return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
}
// Get lower bounds as OpFoldResult.
-std::optional<SmallVector<OpFoldResult>> ForallOp::getLowerBounds() {
+std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
Builder b(getOperation()->getContext());
return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
}
// Get upper bounds as OpFoldResult.
-std::optional<SmallVector<OpFoldResult>> ForallOp::getUpperBounds() {
+std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
Builder b(getOperation()->getContext());
return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
}
// Get steps as OpFoldResult.
-std::optional<SmallVector<OpFoldResult>> ForallOp::getSteps() {
+std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
Builder b(getOperation()->getContext());
return getMixedValues(getStaticStep(), getDynamicStep(), b);
}
@@ -3004,19 +3004,19 @@ void ParallelOp::print(OpAsmPrinter &p) {
SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
-SmallVector<Value> ParallelOp::getInductionVars() {
+SmallVector<Value> ParallelOp::getLoopInductionVars() {
return SmallVector<Value>{getBody()->getArguments()};
}
-std::optional<SmallVector<OpFoldResult>> ParallelOp::getLowerBounds() {
+std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
return getLowerBound();
}
-std::optional<SmallVector<OpFoldResult>> ParallelOp::getUpperBounds() {
+std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
return getUpperBound();
}
-std::optional<SmallVector<OpFoldResult>> ParallelOp::getSteps() {
+std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
return getStep();
}
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 07504a99fecd3..75cd2bfb01de0 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -40,17 +40,18 @@ class SCFLoopLikeTest : public ::testing::Test {
EXPECT_TRUE(maybeSingleIndVar.has_value());
std::optional<SmallVector<OpFoldResult>> maybeLb =
- loopLikeOp.getLowerBounds();
+ loopLikeOp.getLoopLowerBounds();
EXPECT_TRUE(maybeLb.has_value());
EXPECT_EQ((*maybeLb).size(), 1u);
std::optional<SmallVector<OpFoldResult>> maybeUb =
- loopLikeOp.getUpperBounds();
+ loopLikeOp.getLoopUpperBounds();
EXPECT_TRUE(maybeUb.has_value());
EXPECT_EQ((*maybeUb).size(), 1u);
- std::optional<SmallVector<OpFoldResult>> maybeStep = loopLikeOp.getSteps();
+ std::optional<SmallVector<OpFoldResult>> maybeStep =
+ loopLikeOp.getLoopSteps();
EXPECT_TRUE(maybeStep.has_value());
EXPECT_EQ((*maybeStep).size(), 1u);
- EXPECT_EQ(loopLikeOp.getInductionVars().size(), 1u);
+ EXPECT_EQ(loopLikeOp.getLoopInductionVars().size(), 1u);
}
void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -67,17 +68,18 @@ class SCFLoopLikeTest : public ::testing::Test {
EXPECT_FALSE(maybeSingleIndVar.has_value());
std::optional<SmallVector<OpFoldResult>> maybeLb =
- loopLikeOp.getLowerBounds();
+ loopLikeOp.getLoopLowerBounds();
EXPECT_TRUE(maybeLb.has_value());
EXPECT_EQ((*maybeLb).size(), 2u);
std::optional<SmallVector<OpFoldResult>> maybeUb =
- loopLikeOp.getUpperBounds();
+ loopLikeOp.getLoopUpperBounds();
EXPECT_TRUE(maybeUb.has_value());
EXPECT_EQ((*maybeUb).size(), 2u);
- std::optional<SmallVector<OpFoldResult>> maybeStep = loopLikeOp.getSteps();
+ std::optional<SmallVector<OpFoldResult>> maybeStep =
+ loopLikeOp.getLoopSteps();
EXPECT_TRUE(maybeStep.has_value());
EXPECT_EQ((*maybeStep).size(), 2u);
- EXPECT_EQ(loopLikeOp.getInductionVars().size(), 2u);
+ EXPECT_EQ(loopLikeOp.getLoopInductionVars().size(), 2u);
}
MLIRContext context;
>From e0e526210a5ab6ce28fbc5fa5ee24f79cb1ee9a8 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 6 Jun 2024 16:04:47 -0500
Subject: [PATCH 10/13] return optional induction vars
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 10 +++++++---
mlir/include/mlir/Interfaces/LoopLikeInterface.td | 8 ++++----
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 4 ++--
mlir/lib/Dialect/SCF/IR/SCF.cpp | 8 +++++---
mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp | 10 ++++++++--
5 files changed, 26 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 3704b15972278..d425c1c2a47b4 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -511,7 +511,9 @@ def ForallOp : SCF_Op<"forall", [
let extraClassDeclaration = [{
SmallVector<Value> getInductionVars() {
- return getLoopInductionVars();
+ auto maybeInductionVars = getLoopInductionVars();;
+ assert(maybeInductionVars.has_value() && "expected values");
+ return *maybeInductionVars;
}
// Get lower bounds as OpFoldResult.
SmallVector<OpFoldResult> getMixedLowerBound() {
@@ -591,7 +593,7 @@ def ForallOp : SCF_Op<"forall", [
}
::mlir::Value getInductionVar(int64_t idx) {
- return getLoopInductionVars()[idx];
+ return getInductionVars()[idx];
}
::mlir::Block::BlockArgListType getRegionOutArgs() {
@@ -850,7 +852,9 @@ def ParallelOp : SCF_Op<"parallel",
let extraClassDeclaration = [{
// Get induction variables.
SmallVector<Value> getInductionVars() {
- return getLoopInductionVars();
+ auto maybeInductionVars = getLoopInductionVars();;
+ assert(maybeInductionVars.has_value() && "expected values");
+ return *maybeInductionVars;
}
unsigned getNumLoops() { return getStep().size(); }
unsigned getNumReductions() { return getInitVals().size(); }
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 5312ace4db68e..2e6aabda30b07 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -95,12 +95,12 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
InterfaceMethod<[{
Return all induction variables.
}],
- /*retTy=*/"::llvm::SmallVector<::mlir::Value>",
+ /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::Value>>",
/*methodName=*/"getLoopInductionVars",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return {};
+ return std::nullopt;
}]
>,
InterfaceMethod<[{
@@ -235,8 +235,8 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/// std::nullopt.
::std::optional<::mlir::Value> getSingleInductionVar() {
auto inductionVars = this->getLoopInductionVars();
- if (inductionVars.size() == 1)
- return inductionVars[0];
+ if (inductionVars.has_value() && (*inductionVars).size() == 1)
+ return (*inductionVars)[0];
return std::nullopt;
}
/// Return the single lower bound value or attribute if it exists, otherwise
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d5cb04743dfb9..0a58d2fdb02f5 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2454,8 +2454,8 @@ bool AffineForOp::matchingBoundOperandList() {
SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
-SmallVector<Value> AffineForOp::getLoopInductionVars() {
- return {getInductionVar()};
+std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
+ return SmallVector<Value>{getInductionVar()};
}
std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index e921177f73215..c00579443ea29 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -378,7 +378,9 @@ LogicalResult ForOp::verifyRegions() {
return success();
}
-SmallVector<Value> ForOp::getLoopInductionVars() { return {getInductionVar()}; }
+std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
+ return SmallVector<Value, 1>{getInductionVar()};
+}
std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
return SmallVector<OpFoldResult, 1>{OpFoldResult(getLowerBound())};
@@ -1426,7 +1428,7 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
return storeOps;
}
-SmallVector<Value> ForallOp::getLoopInductionVars() {
+std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
}
@@ -3004,7 +3006,7 @@ void ParallelOp::print(OpAsmPrinter &p) {
SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
-SmallVector<Value> ParallelOp::getLoopInductionVars() {
+std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
return SmallVector<Value>{getBody()->getArguments()};
}
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 75cd2bfb01de0..20dbc8d362d27 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -51,7 +51,10 @@ class SCFLoopLikeTest : public ::testing::Test {
loopLikeOp.getLoopSteps();
EXPECT_TRUE(maybeStep.has_value());
EXPECT_EQ((*maybeStep).size(), 1u);
- EXPECT_EQ(loopLikeOp.getLoopInductionVars().size(), 1u);
+ std::optional<SmallVector<Value>> maybeInductionVars =
+ loopLikeOp.getLoopInductionVars();
+ EXPECT_TRUE(maybeInductionVars.has_value());
+ EXPECT_EQ((*maybeInductionVars).size(), 1u);
}
void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -79,7 +82,10 @@ class SCFLoopLikeTest : public ::testing::Test {
loopLikeOp.getLoopSteps();
EXPECT_TRUE(maybeStep.has_value());
EXPECT_EQ((*maybeStep).size(), 2u);
- EXPECT_EQ(loopLikeOp.getLoopInductionVars().size(), 2u);
+ std::optional<SmallVector<Value>> maybeInductionVars =
+ loopLikeOp.getLoopInductionVars();
+ EXPECT_TRUE(maybeInductionVars.has_value());
+ EXPECT_EQ((*maybeInductionVars).size(), 2u);
}
MLIRContext context;
>From 7115a6e08bba43fe9750f8cef5c73f6be1b373fd Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 7 Jun 2024 11:45:44 -0500
Subject: [PATCH 11/13] address review comments
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 19 ++++++------
.../mlir/Interfaces/LoopLikeInterface.td | 30 +++++++++++++------
mlir/lib/Dialect/SCF/IR/SCF.cpp | 8 ++---
.../Dialect/SCF/LoopLikeSCFOpsTest.cpp | 16 +++++-----
4 files changed, 43 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index d425c1c2a47b4..f35ea962bea16 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -510,28 +510,29 @@ def ForallOp : SCF_Op<"forall", [
];
let extraClassDeclaration = [{
+ /// Get induction variables.
SmallVector<Value> getInductionVars() {
- auto maybeInductionVars = getLoopInductionVars();;
+ std::optional<SmallVector<Value>> maybeInductionVars = getLoopInductionVars();
assert(maybeInductionVars.has_value() && "expected values");
return *maybeInductionVars;
}
- // Get lower bounds as OpFoldResult.
+ /// Get lower bounds as OpFoldResult.
SmallVector<OpFoldResult> getMixedLowerBound() {
- auto maybeLowerBounds = getLoopLowerBounds();
+ std::optional<SmallVector<OpFoldResult>> maybeLowerBounds = getLoopLowerBounds();
assert(maybeLowerBounds.has_value() && "expected values");
return *maybeLowerBounds;
}
- // Get upper bounds as OpFoldResult.
+ /// Get upper bounds as OpFoldResult.
SmallVector<OpFoldResult> getMixedUpperBound() {
- auto maybeUpperBounds = getLoopUpperBounds();
+ std::optional<SmallVector<OpFoldResult>> maybeUpperBounds = getLoopUpperBounds();
assert(maybeUpperBounds.has_value() && "expected values");
return *maybeUpperBounds;
}
- // Get steps as OpFoldResult.
+ /// Get steps as OpFoldResult.
SmallVector<OpFoldResult> getMixedStep() {
- auto maybeSteps = getLoopSteps();
+ std::optional<SmallVector<OpFoldResult>> maybeSteps = getLoopSteps();
assert(maybeSteps.has_value() && "expected values");
return *maybeSteps;
}
@@ -850,9 +851,9 @@ def ParallelOp : SCF_Op<"parallel",
];
let extraClassDeclaration = [{
- // Get induction variables.
+ /// Get induction variables.
SmallVector<Value> getInductionVars() {
- auto maybeInductionVars = getLoopInductionVars();;
+ std::optional<SmallVector<Value>> maybeInductionVars = getLoopInductionVars();
assert(maybeInductionVars.has_value() && "expected values");
return *maybeInductionVars;
}
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 2e6aabda30b07..b748d5e29114a 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -93,47 +93,59 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
}]
>,
InterfaceMethod<[{
- Return all induction variables.
+ Return all induction variables, if they exist. If the op has no notion of
+ induction variable, then return std::nullopt. If it does have
+ a notion but an instance doesn't have induction variables, then
+ return empty vector.
}],
/*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::Value>>",
/*methodName=*/"getLoopInductionVars",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return std::nullopt;
+ return ::std::nullopt;
}]
>,
InterfaceMethod<[{
- Return all lower bounds.
+ Return all lower bounds, if they exist. If the op has no notion of
+ lower bounds, then return std::nullopt. If it does have
+ a notion but an instance doesn't have lower bounds, then
+ return empty vector.
}],
/*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
/*methodName=*/"getLoopLowerBounds",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return std::nullopt;
+ return ::std::nullopt;
}]
>,
InterfaceMethod<[{
- Return all steps.
+ Return all steps, if they exist. If the op has no notion of
+ steps, then return std::nullopt. If it does have
+ a notion but an instance doesn't have steps, then
+ return empty vector.
}],
- /*retTy=*/"std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
+ /*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
/*methodName=*/"getLoopSteps",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return std::nullopt;
+ return ::std::nullopt;
}]
>,
InterfaceMethod<[{
- Return all upper bounds.
+ Return all upper bounds, if they exist. If the op has no notion of
+ upper bounds, then return std::nullopt. If it does have
+ a notion but an instance doesn't have upper bounds, then
+ return empty vector.
}],
/*retTy=*/"::std::optional<::llvm::SmallVector<::mlir::OpFoldResult>>",
/*methodName=*/"getLoopUpperBounds",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return std::nullopt;
+ return ::std::nullopt;
}]
>,
InterfaceMethod<[{
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c00579443ea29..5e94f4dc612a7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -379,19 +379,19 @@ LogicalResult ForOp::verifyRegions() {
}
std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
- return SmallVector<Value, 1>{getInductionVar()};
+ return SmallVector<Value>{getInductionVar()};
}
std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
- return SmallVector<OpFoldResult, 1>{OpFoldResult(getLowerBound())};
+ return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
}
std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
- return SmallVector<OpFoldResult, 1>{OpFoldResult(getStep())};
+ return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
}
std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
- return SmallVector<OpFoldResult, 1>{OpFoldResult(getUpperBound())};
+ return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
}
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 20dbc8d362d27..53a4af14d119a 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -41,19 +41,19 @@ class SCFLoopLikeTest : public ::testing::Test {
std::optional<SmallVector<OpFoldResult>> maybeLb =
loopLikeOp.getLoopLowerBounds();
- EXPECT_TRUE(maybeLb.has_value());
+ ASSERT_TRUE(maybeLb.has_value());
EXPECT_EQ((*maybeLb).size(), 1u);
std::optional<SmallVector<OpFoldResult>> maybeUb =
loopLikeOp.getLoopUpperBounds();
- EXPECT_TRUE(maybeUb.has_value());
+ ASSERT_TRUE(maybeUb.has_value());
EXPECT_EQ((*maybeUb).size(), 1u);
std::optional<SmallVector<OpFoldResult>> maybeStep =
loopLikeOp.getLoopSteps();
- EXPECT_TRUE(maybeStep.has_value());
+ ASSERT_TRUE(maybeStep.has_value());
EXPECT_EQ((*maybeStep).size(), 1u);
std::optional<SmallVector<Value>> maybeInductionVars =
loopLikeOp.getLoopInductionVars();
- EXPECT_TRUE(maybeInductionVars.has_value());
+ ASSERT_TRUE(maybeInductionVars.has_value());
EXPECT_EQ((*maybeInductionVars).size(), 1u);
}
@@ -72,19 +72,19 @@ class SCFLoopLikeTest : public ::testing::Test {
std::optional<SmallVector<OpFoldResult>> maybeLb =
loopLikeOp.getLoopLowerBounds();
- EXPECT_TRUE(maybeLb.has_value());
+ ASSERT_TRUE(maybeLb.has_value());
EXPECT_EQ((*maybeLb).size(), 2u);
std::optional<SmallVector<OpFoldResult>> maybeUb =
loopLikeOp.getLoopUpperBounds();
- EXPECT_TRUE(maybeUb.has_value());
+ ASSERT_TRUE(maybeUb.has_value());
EXPECT_EQ((*maybeUb).size(), 2u);
std::optional<SmallVector<OpFoldResult>> maybeStep =
loopLikeOp.getLoopSteps();
- EXPECT_TRUE(maybeStep.has_value());
+ ASSERT_TRUE(maybeStep.has_value());
EXPECT_EQ((*maybeStep).size(), 2u);
std::optional<SmallVector<Value>> maybeInductionVars =
loopLikeOp.getLoopInductionVars();
- EXPECT_TRUE(maybeInductionVars.has_value());
+ ASSERT_TRUE(maybeInductionVars.has_value());
EXPECT_EQ((*maybeInductionVars).size(), 2u);
}
>From 6336fdf28f06c7525bcf6822a386f8c4cabe3c2d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 7 Jun 2024 15:25:40 -0500
Subject: [PATCH 12/13] update after rebase
---
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index ce20730459c2a..e3660e89fb684 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1198,9 +1198,9 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) {
bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
LoopLikeOpInterface &source) {
auto iterSpaceEq =
- target.getMixedLowerBound() == source.getMixedLowerBound() &&
- target.getMixedUpperBound() == source.getMixedUpperBound() &&
- target.getMixedStep() == source.getMixedStep();
+ target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
+ target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
+ target.getLoopSteps() == source.getLoopSteps();
auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
auto forAllSource = dyn_cast<scf::ForallOp>(*source);
if (forAllTarget && forAllSource)
>From 86406c335dd216ac91a941102c060bc680af10b1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 8 Jun 2024 23:31:57 -0500
Subject: [PATCH 13/13] refactor main parallel fusion logic from fuseIfLegal to
util func
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 6 -
.../SCF/Transforms/ParallelLoopFusion.cpp | 159 +++++++++++++++++-
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 140 ++++++---------
3 files changed, 208 insertions(+), 97 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index ab9d154aa480d..ac4434b337890 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -163,12 +163,6 @@ void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
bool checkFusionStructuralLegality(LoopLikeOpInterface &target,
LoopLikeOpInterface &source);
-/// Prepends operations of firstPloop's body into secondPloop's body.
-/// Updates secondPloop with new loop.
-void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
- OpBuilder builder,
- llvm::function_ref<bool(Value, Value)> mayAlias);
-
/// Given two scf.forall loops, `target` and `source`, fuses `target` into
/// `source`. Assumes that the given loops are siblings and are independent of
/// each other.
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index abac91cfaf7d9..326a8f93162b9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -31,6 +31,163 @@ namespace mlir {
using namespace mlir;
using namespace mlir::scf;
+/// Verify there are no nested ParallelOps.
+static bool hasNestedParallelOp(ParallelOp ploop) {
+ auto walkResult =
+ ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
+ return walkResult.wasInterrupted();
+}
+
+/// Verify equal iteration spaces.
+static bool equalIterationSpaces(ParallelOp firstPloop,
+ ParallelOp secondPloop) {
+ if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+ return false;
+
+ auto matchOperands = [&](const OperandRange &lhs,
+ const OperandRange &rhs) -> bool {
+ // TODO: Extend this to support aliases and equal constants.
+ return std::equal(lhs.begin(), lhs.end(), rhs.begin());
+ };
+ return matchOperands(firstPloop.getLowerBound(),
+ secondPloop.getLowerBound()) &&
+ matchOperands(firstPloop.getUpperBound(),
+ secondPloop.getUpperBound()) &&
+ matchOperands(firstPloop.getStep(), secondPloop.getStep());
+}
+
+/// Checks if the parallel loops have mixed access to the same buffers. Returns
+/// `true` only if every load in the second loop from a buffer stored to by the
+/// first loop reads the stored indices (conservative on possible aliasing).
+static bool haveNoReadsAfterWriteExceptSameIndex(
+ ParallelOp firstPloop, ParallelOp secondPloop,
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
+ SmallVector<Value> bufferStoresVec;
+ firstPloop.getBody()->walk([&](memref::StoreOp store) {
+ bufferStores[store.getMemRef()].push_back(store.getIndices());
+ bufferStoresVec.emplace_back(store.getMemRef());
+ });
+ auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
+ Value loadMem = load.getMemRef();
+ // Stop if the memref is defined in the load's own block; that case would
+ // need careful alias analysis.
+ auto *memrefDef = loadMem.getDefiningOp();
+ if (memrefDef && memrefDef->getBlock() == load->getBlock())
+ return WalkResult::interrupt();
+
+ for (Value store : bufferStoresVec)
+ if (store != loadMem && mayAlias(store, loadMem))
+ return WalkResult::interrupt();
+
+ auto write = bufferStores.find(loadMem);
+ if (write == bufferStores.end())
+ return WalkResult::advance();
+
+ // Check that at least one store was retrieved.
+ if (!write->second.size())
+ return WalkResult::interrupt();
+
+ auto storeIndices = write->second.front();
+
+ // Multiple writes to the same memref are allowed only on the same indices.
+ for (const auto &othStoreIndices : write->second) {
+ if (othStoreIndices != storeIndices)
+ return WalkResult::interrupt();
+ }
+
+ // Check that the load indices of secondPloop coincide with store indices of
+ // firstPloop for the same memrefs.
+ auto loadIndices = load.getIndices();
+ if (storeIndices.size() != loadIndices.size())
+ return WalkResult::interrupt();
+ for (int i = 0, e = storeIndices.size(); i < e; ++i) {
+ if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
+ loadIndices[i]) {
+ auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
+ auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
+ if (storeIndexDefOp && loadIndexDefOp) {
+ if (!isMemoryEffectFree(storeIndexDefOp))
+ return WalkResult::interrupt();
+ if (!isMemoryEffectFree(loadIndexDefOp))
+ return WalkResult::interrupt();
+ if (!OperationEquivalence::isEquivalentTo(
+ storeIndexDefOp, loadIndexDefOp,
+ [&](Value storeIndex, Value loadIndex) {
+ if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
+ firstToSecondPloopIndices.lookupOrDefault(loadIndex))
+ return failure();
+ else
+ return success();
+ },
+ /*markEquivalent=*/nullptr,
+ OperationEquivalence::Flags::IgnoreLocations)) {
+ return WalkResult::interrupt();
+ }
+ } else
+ return WalkResult::interrupt();
+ }
+ }
+ return WalkResult::advance();
+ });
+ return !walkResult.wasInterrupted();
+}
+
+/// Analyzes dependencies in the most primitive way by checking simple read and
+/// write patterns.
+static LogicalResult
+verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ if (!haveNoReadsAfterWriteExceptSameIndex(
+ firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
+ return failure();
+
+ IRMapping secondToFirstPloopIndices;
+ secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
+ firstPloop.getBody()->getArguments());
+ return success(haveNoReadsAfterWriteExceptSameIndex(
+ secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
+}
+
+static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ return !hasNestedParallelOp(firstPloop) &&
+ !hasNestedParallelOp(secondPloop) &&
+ equalIterationSpaces(firstPloop, secondPloop) &&
+ succeeded(verifyDependencies(firstPloop, secondPloop,
+ firstToSecondPloopIndices, mayAlias));
+}
+
+/// Fuses firstPloop into secondPloop if legal (see isFusionLegal and the
+/// dominance check below), then updates secondPloop to the fused loop.
+static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
+ OpBuilder builder,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ Block *block1 = firstPloop.getBody();
+ Block *block2 = secondPloop.getBody();
+ IRMapping firstToSecondPloopIndices;
+ firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
+
+ if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
+ mayAlias))
+ return;
+
+ DominanceInfo dom;
+ // We are fusing first loop into second, make sure there are no users of the
+ // first loop results between loops.
+ for (Operation *user : firstPloop->getUsers())
+ if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk=*/false))
+ return;
+
+ IRRewriter rewriter(builder);
+ // `secondPloop` is a reference; point it at the fused loop for the caller.
+ secondPloop = mlir::fuseIndependentSiblingParallelLoops(
+ firstPloop, secondPloop, rewriter);
+}
+
void mlir::scf::naivelyFuseParallelOps(
Region ®ion, llvm::function_ref<bool(Value, Value)> mayAlias) {
OpBuilder b(region);
@@ -59,7 +216,7 @@ void mlir::scf::naivelyFuseParallelOps(
}
for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
- mlir::fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
+ fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
}
}
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index e3660e89fb684..5f58767be409d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1188,13 +1188,6 @@ static bool equalIterationSpaces(scf::ParallelOp firstPloop,
// Fusion related helpers
//===----------------------------------------------------------------------===//
-/// Verify there are no nested ParallelOps.
-static bool hasNestedParallelOp(scf::ParallelOp ploop) {
- auto walkResult = ploop.getBody()->walk(
- [](scf::ParallelOp) { return WalkResult::interrupt(); });
- return walkResult.wasInterrupted();
-}
-
bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
LoopLikeOpInterface &source) {
auto iterSpaceEq =
@@ -1209,86 +1202,6 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
return iterSpaceEq;
}
-static bool isFusionLegal(scf::ParallelOp firstPloop,
- scf::ParallelOp secondPloop,
- const IRMapping &firstToSecondPloopIndices,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
- return !hasNestedParallelOp(firstPloop) &&
- !hasNestedParallelOp(secondPloop) &&
- equalIterationSpaces(firstPloop, secondPloop) &&
- succeeded(verifyDependencies(firstPloop, secondPloop,
- firstToSecondPloopIndices, mayAlias));
-}
-
-void mlir::fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
- OpBuilder builder,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
- Block *block1 = firstPloop.getBody();
- Block *block2 = secondPloop.getBody();
- IRMapping firstToSecondPloopIndices;
- firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
-
- if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
- mayAlias))
- return;
-
- DominanceInfo dom;
- // We are fusing first loop into second, make sure there are no users of the
- // first loop results between loops.
- for (Operation *user : firstPloop->getUsers())
- if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
- return;
-
- ValueRange inits1 = firstPloop.getInitVals();
- ValueRange inits2 = secondPloop.getInitVals();
-
- SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
- newInitVars.append(inits2.begin(), inits2.end());
-
- IRRewriter b(builder);
- b.setInsertionPoint(secondPloop);
- auto newSecondPloop = b.create<scf::ParallelOp>(
- secondPloop.getLoc(), secondPloop.getLowerBound(),
- secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
-
- Block *newBlock = newSecondPloop.getBody();
- auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
- auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
-
- b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
- newBlock->getArguments());
- b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
- newBlock->getArguments());
-
- ValueRange results = newSecondPloop.getResults();
- if (!results.empty()) {
- b.setInsertionPointToEnd(newBlock);
-
- ValueRange reduceArgs1 = term1.getOperands();
- ValueRange reduceArgs2 = term2.getOperands();
- SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
- newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
-
- auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
-
- for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
- term1.getReductions(), term2.getReductions()))) {
- Block &oldRedBlock = reg.front();
- Block &newRedBlock = newReduceOp.getReductions()[i].front();
- b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
- newRedBlock.getArguments());
- }
-
- firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
- secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
- }
- term1->erase();
- term2->erase();
- firstPloop.erase();
- secondPloop.erase();
- secondPloop = newSecondPloop;
-}
-
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForallOp source,
RewriterBase &rewriter) {
@@ -1393,7 +1306,67 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
 scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
     scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
-  auto mayAlias = [&](Value val1, Value val2) -> bool { return false; };
-  mlir::fuseIfLegal(target, source, rewriter, mayAlias);
-  return source;
+  // Restore the caller's insertion point on exit; it is moved below.
+  OpBuilder::InsertionGuard guard(rewriter);
+
+  Block *block1 = target.getBody();
+  Block *block2 = source.getBody();
+  auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
+  auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
+
+  // The fused loop carries `target`'s init values followed by `source`'s, so
+  // its results split the same way (see the replaceOp calls below).
+  ValueRange inits1 = target.getInitVals();
+  ValueRange inits2 = source.getInitVals();
+
+  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+  newInitVars.append(inits2.begin(), inits2.end());
+
+  // The fused loop reuses `source`'s bounds and position. No legality check is
+  // done here: this assumes identical iteration spaces and that `target`
+  // precedes `source` (caller's responsibility).
+  rewriter.setInsertionPoint(source);
+  auto fusedLoop = rewriter.create<scf::ParallelOp>(
+      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+      source.getStep(), newInitVars);
+  Block *newBlock = fusedLoop.getBody();
+  // Both bodies are prepended, `source`'s first, so `target`'s ops end up
+  // ahead of `source`'s; their block arguments map onto the fused loop's.
+  rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(),
+                             newBlock->getArguments());
+  rewriter.inlineBlockBefore(block1, newBlock, newBlock->begin(),
+                             newBlock->getArguments());
+
+  ValueRange results = fusedLoop.getResults();
+  if (!results.empty()) {
+    // Merge the two scf.reduce terminators into one at the end of the fused
+    // body: `target`'s operands and reduction regions first, then `source`'s.
+    rewriter.setInsertionPointToEnd(newBlock);
+
+    ValueRange reduceArgs1 = term1.getOperands();
+    ValueRange reduceArgs2 = term2.getOperands();
+    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+    auto newReduceOp =
+        rewriter.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+
+    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+             term1.getReductions(), term2.getReductions()))) {
+      Block &oldRedBlock = reg.front();
+      Block &newRedBlock = newReduceOp.getReductions()[i].front();
+      rewriter.inlineBlockBefore(&oldRedBlock, &newRedBlock,
+                                 newRedBlock.begin(),
+                                 newRedBlock.getArguments());
+    }
+  }
+
+  // Erase and replace through the rewriter (not Operation::erase /
+  // replaceAllUsesWith) so any attached listener sees every IR change.
+  rewriter.eraseOp(term1);
+  rewriter.eraseOp(term2);
+  rewriter.replaceOp(target, results.take_front(inits1.size()));
+  rewriter.replaceOp(source, results.take_back(inits2.size()));
+
+  return fusedLoop;
 }
More information about the Mlir-commits
mailing list