[Mlir-commits] [mlir] 9ca4664 - [MLIR][SCF] Fix normalizeForallOp helper function (#138615)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 9 12:59:02 PDT 2025
Author: Colin De Vlieghere
Date: 2025-05-09T12:58:59-07:00
New Revision: 9ca46640062a6c0b955d16ad6f88305b534af8a3
URL: https://github.com/llvm/llvm-project/commit/9ca46640062a6c0b955d16ad6f88305b534af8a3
DIFF: https://github.com/llvm/llvm-project/commit/9ca46640062a6c0b955d16ad6f88305b534af8a3.diff
LOG: [MLIR][SCF] Fix normalizeForallOp helper function (#138615)
Previously, the `normalizeForallOp` function did not work properly: the
newly created op was not returned, and the op also failed
verification.
This patch fixes the helper function and adds a unit test for it.
Added:
Modified:
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/unittests/Dialect/SCF/CMakeLists.txt
mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index e9471c1dbd0b7..d9550fe18dc02 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1482,30 +1482,41 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
- if (llvm::all_of(
- lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
- llvm::all_of(
- steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
+ if (forallOp.isNormalized())
return forallOp;
- }
- SmallVector<OpFoldResult> newLbs, newUbs, newSteps;
+ OpBuilder::InsertionGuard g(rewriter);
+ auto loc = forallOp.getLoc();
+ rewriter.setInsertionPoint(forallOp);
+ SmallVector<OpFoldResult> newUbs;
for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
Range normalizedLoopParams =
- emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step);
- newLbs.push_back(normalizedLoopParams.offset);
+ emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
newUbs.push_back(normalizedLoopParams.size);
- newSteps.push_back(normalizedLoopParams.stride);
}
+ (void)foldDynamicIndexList(newUbs);
+ // Use the normalized builder since the lower bounds are always 0 and the
+ // steps are always 1.
auto normalizedForallOp = rewriter.create<scf::ForallOp>(
- forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
- forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {});
+ loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
+ [](OpBuilder &, Location, ValueRange) {});
rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
normalizedForallOp.getBodyRegion(),
normalizedForallOp.getBodyRegion().begin());
+ // Remove the original empty block in the new loop.
+ rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
+
+ rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
+ // Update the users of the original loop variables.
+ for (auto [idx, iv] :
+ llvm::enumerate(normalizedForallOp.getInductionVars())) {
+ auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
+ auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
+ denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
+ }
- rewriter.replaceAllOpUsesWith(forallOp, normalizedForallOp);
- return success();
+ rewriter.replaceOp(forallOp, normalizedForallOp);
+ return normalizedForallOp;
}
diff --git a/mlir/unittests/Dialect/SCF/CMakeLists.txt b/mlir/unittests/Dialect/SCF/CMakeLists.txt
index c0c1757b80fb5..83cefbcabf4d9 100644
--- a/mlir/unittests/Dialect/SCF/CMakeLists.txt
+++ b/mlir/unittests/Dialect/SCF/CMakeLists.txt
@@ -5,4 +5,5 @@ mlir_target_link_libraries(MLIRSCFTests
PRIVATE
MLIRIR
MLIRSCFDialect
+ MLIRSCFUtils
)
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 53a4af14d119a..fecd960d694b1 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -6,11 +6,16 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "gtest/gtest.h"
using namespace mlir;
@@ -23,7 +28,8 @@ using namespace mlir::scf;
class SCFLoopLikeTest : public ::testing::Test {
protected:
SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
- context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
+ context.loadDialect<affine::AffineDialect, arith::ArithDialect,
+ scf::SCFDialect>();
}
void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -88,6 +94,24 @@ class SCFLoopLikeTest : public ::testing::Test {
EXPECT_EQ((*maybeInductionVars).size(), 2u);
}
+ void checkNormalized(LoopLikeOpInterface loopLikeOp) {
+ std::optional<SmallVector<OpFoldResult>> maybeLb =
+ loopLikeOp.getLoopLowerBounds();
+ ASSERT_TRUE(maybeLb.has_value());
+ std::optional<SmallVector<OpFoldResult>> maybeStep =
+ loopLikeOp.getLoopSteps();
+ ASSERT_TRUE(maybeStep.has_value());
+
+ auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
+ return llvm::all_of(results, [&](OpFoldResult ofr) {
+ auto intValue = getConstantIntValue(ofr);
+ return intValue.has_value() && intValue == val;
+ });
+ };
+ EXPECT_TRUE(allEqual(*maybeLb, 0));
+ EXPECT_TRUE(allEqual(*maybeStep, 1));
+ }
+
MLIRContext context;
OpBuilder b;
Location loc;
@@ -138,3 +162,36 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
ValueRange({step->getResult(), step->getResult()}), ValueRange());
checkMultidimensional(parallelOp.get());
}
+
+TEST_F(SCFLoopLikeTest, testForallNormalize) {
+ OwningOpRef<arith::ConstantIndexOp> lb =
+ b.create<arith::ConstantIndexOp>(loc, 1);
+ OwningOpRef<arith::ConstantIndexOp> ub =
+ b.create<arith::ConstantIndexOp>(loc, 10);
+ OwningOpRef<arith::ConstantIndexOp> step =
+ b.create<arith::ConstantIndexOp>(loc, 3);
+
+ scf::ForallOp forallOp = b.create<scf::ForallOp>(
+ loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
+ ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
+ ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
+ ValueRange(), std::nullopt);
+ // Create a user of the induction variable. Bitcast is chosen for simplicity
+ // since it is unary.
+ b.setInsertionPointToStart(forallOp.getBody());
+ b.create<arith::BitcastOp>(UnknownLoc::get(&context), b.getF64Type(),
+ forallOp.getInductionVar(0));
+ IRRewriter rewriter(b);
+ FailureOr<scf::ForallOp> maybeNormalizedForallOp =
+ normalizeForallOp(rewriter, forallOp);
+ EXPECT_TRUE(succeeded(maybeNormalizedForallOp));
+ OwningOpRef<scf::ForallOp> normalizedForallOp(*maybeNormalizedForallOp);
+ checkNormalized(normalizedForallOp.get());
+
+ // Check that the IV user has been updated to use the denormalized variable.
+ Block *body = normalizedForallOp->getBody();
+ auto bitcastOps = body->getOps<arith::BitcastOp>();
+ ASSERT_EQ(std::distance(bitcastOps.begin(), bitcastOps.end()), 1);
+ arith::BitcastOp ivUser = *bitcastOps.begin();
+ ASSERT_NE(ivUser.getIn(), normalizedForallOp->getInductionVar(0));
+}
More information about the Mlir-commits
mailing list