[Mlir-commits] [mlir] [MLIR][SCF] Fix normalizeForallOp helper function (PR #138615)

Colin De Vlieghere llvmlistbot at llvm.org
Wed May 7 12:15:11 PDT 2025


https://github.com/Cubevoid updated https://github.com/llvm/llvm-project/pull/138615

>From e1acacc4528cae9e9a99857a7e18de353ceb53dc Mon Sep 17 00:00:00 2001
From: Colin De Vlieghere <cdevlieghere at tesla.com>
Date: Mon, 5 May 2025 19:10:24 -0700
Subject: [PATCH 1/2] [MLIR][SCF] Fix normalizeForallOp helper function

Previously the `normalizeForallOp` function did not work properly: the newly
created op was not being returned, and the op also failed
verification.

This patch fixes the helper function and adds a unit test for it.
---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 25 ++++++-----
 mlir/unittests/Dialect/SCF/CMakeLists.txt     |  1 +
 .../Dialect/SCF/LoopLikeSCFOpsTest.cpp        | 44 ++++++++++++++++++-
 3 files changed, 57 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index e9471c1dbd0b7..d6bed551ec8fa 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1482,30 +1482,31 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
   SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
   SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
 
-  if (llvm::all_of(
-          lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
-      llvm::all_of(
-          steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
+  if (forallOp.isNormalized())
     return forallOp;
-  }
 
-  SmallVector<OpFoldResult> newLbs, newUbs, newSteps;
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+  SmallVector<OpFoldResult> newUbs;
   for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
     Range normalizedLoopParams =
         emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step);
-    newLbs.push_back(normalizedLoopParams.offset);
     newUbs.push_back(normalizedLoopParams.size);
-    newSteps.push_back(normalizedLoopParams.stride);
   }
+  (void)foldDynamicIndexList(newUbs);
 
+  // Use the normalized builder since the lower bounds are always 0 and the
+  // steps are always 1.
   auto normalizedForallOp = rewriter.create<scf::ForallOp>(
-      forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
-      forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {});
+      forallOp.getLoc(), newUbs, forallOp.getOutputs(), forallOp.getMapping(),
+      [](OpBuilder &, Location, ValueRange) {});
 
   rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
                               normalizedForallOp.getBodyRegion(),
                               normalizedForallOp.getBodyRegion().begin());
+  // Remove the original empty block in the new loop.
+  rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
 
-  rewriter.replaceAllOpUsesWith(forallOp, normalizedForallOp);
-  return success();
+  rewriter.replaceOp(forallOp, normalizedForallOp);
+  return normalizedForallOp;
 }
diff --git a/mlir/unittests/Dialect/SCF/CMakeLists.txt b/mlir/unittests/Dialect/SCF/CMakeLists.txt
index c0c1757b80fb5..83cefbcabf4d9 100644
--- a/mlir/unittests/Dialect/SCF/CMakeLists.txt
+++ b/mlir/unittests/Dialect/SCF/CMakeLists.txt
@@ -5,4 +5,5 @@ mlir_target_link_libraries(MLIRSCFTests
   PRIVATE
   MLIRIR
   MLIRSCFDialect
+  MLIRSCFUtils
 )
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 53a4af14d119a..e4a3a857a747e 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -6,11 +6,15 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "gtest/gtest.h"
 
 using namespace mlir;
@@ -23,7 +27,7 @@ using namespace mlir::scf;
 class SCFLoopLikeTest : public ::testing::Test {
 protected:
   SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
-    context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
+    context.loadDialect<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect>();
   }
 
   void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -88,6 +92,24 @@ class SCFLoopLikeTest : public ::testing::Test {
     EXPECT_EQ((*maybeInductionVars).size(), 2u);
   }
 
+  void checkNormalized(LoopLikeOpInterface loopLikeOp) {
+    std::optional<SmallVector<OpFoldResult>> maybeLb =
+        loopLikeOp.getLoopLowerBounds();
+    ASSERT_TRUE(maybeLb.has_value());
+    std::optional<SmallVector<OpFoldResult>> maybeStep =
+        loopLikeOp.getLoopSteps();
+    ASSERT_TRUE(maybeStep.has_value());
+
+    auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
+      return llvm::all_of(results, [&](OpFoldResult ofr) {
+        auto intValue = getConstantIntValue(ofr);
+        return intValue.has_value() && intValue == val;
+      });
+    };
+    EXPECT_TRUE(allEqual(*maybeLb, 0));
+    EXPECT_TRUE(allEqual(*maybeStep, 1));
+  }
+
   MLIRContext context;
   OpBuilder b;
   Location loc;
@@ -138,3 +160,23 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
       ValueRange({step->getResult(), step->getResult()}), ValueRange());
   checkMultidimensional(parallelOp.get());
 }
+
+TEST_F(SCFLoopLikeTest, testForallNormalize) {
+  OwningOpRef<arith::ConstantIndexOp> lb =
+      b.create<arith::ConstantIndexOp>(loc, 1);
+  OwningOpRef<arith::ConstantIndexOp> ub =
+      b.create<arith::ConstantIndexOp>(loc, 10);
+  OwningOpRef<arith::ConstantIndexOp> step =
+      b.create<arith::ConstantIndexOp>(loc, 3);
+
+  scf::ForallOp forallOp = b.create<scf::ForallOp>(
+      loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
+      ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
+      ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
+      ValueRange(), std::nullopt);
+  IRRewriter rewriter(b);
+  FailureOr<scf::ForallOp> maybeNormalizedForallOp = normalizeForallOp(rewriter, forallOp);
+  EXPECT_TRUE(succeeded(maybeNormalizedForallOp));
+  OwningOpRef<scf::ForallOp> normalizedForallOp(*maybeNormalizedForallOp);
+  checkNormalized(normalizedForallOp.get());
+}

>From 3bb4d9b7c002faf965c6c193f1dde3a67099b740 Mon Sep 17 00:00:00 2001
From: Colin De Vlieghere <cdevlieghere at tesla.com>
Date: Wed, 7 May 2025 12:14:39 -0700
Subject: [PATCH 2/2] Update users of IVs inside loop

---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 14 ++++++++++++--
 .../Dialect/SCF/LoopLikeSCFOpsTest.cpp        | 19 +++++++++++++++++--
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index d6bed551ec8fa..d9550fe18dc02 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1486,11 +1486,12 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
     return forallOp;
 
   OpBuilder::InsertionGuard g(rewriter);
+  auto loc = forallOp.getLoc();
   rewriter.setInsertionPoint(forallOp);
   SmallVector<OpFoldResult> newUbs;
   for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
     Range normalizedLoopParams =
-        emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step);
+        emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
     newUbs.push_back(normalizedLoopParams.size);
   }
   (void)foldDynamicIndexList(newUbs);
@@ -1498,7 +1499,7 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
   // Use the normalized builder since the lower bounds are always 0 and the
   // steps are always 1.
   auto normalizedForallOp = rewriter.create<scf::ForallOp>(
-      forallOp.getLoc(), newUbs, forallOp.getOutputs(), forallOp.getMapping(),
+      loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
       [](OpBuilder &, Location, ValueRange) {});
 
   rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
@@ -1507,6 +1508,15 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
   // Remove the original empty block in the new loop.
   rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
 
+  rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
+  // Update the users of the original loop variables.
+  for (auto [idx, iv] :
+       llvm::enumerate(normalizedForallOp.getInductionVars())) {
+    auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
+    auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
+    denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
+  }
+
   rewriter.replaceOp(forallOp, normalizedForallOp);
   return normalizedForallOp;
 }
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index e4a3a857a747e..fecd960d694b1 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OwningOpRef.h"
@@ -27,7 +28,8 @@ using namespace mlir::scf;
 class SCFLoopLikeTest : public ::testing::Test {
 protected:
   SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
-    context.loadDialect<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect>();
+    context.loadDialect<affine::AffineDialect, arith::ArithDialect,
+                        scf::SCFDialect>();
   }
 
   void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -174,9 +176,22 @@ TEST_F(SCFLoopLikeTest, testForallNormalize) {
       ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
       ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
       ValueRange(), std::nullopt);
+  // Create a user of the induction variable. Bitcast is chosen for simplicity
+  // since it is unary.
+  b.setInsertionPointToStart(forallOp.getBody());
+  b.create<arith::BitcastOp>(UnknownLoc::get(&context), b.getF64Type(),
+                             forallOp.getInductionVar(0));
   IRRewriter rewriter(b);
-  FailureOr<scf::ForallOp> maybeNormalizedForallOp = normalizeForallOp(rewriter, forallOp);
+  FailureOr<scf::ForallOp> maybeNormalizedForallOp =
+      normalizeForallOp(rewriter, forallOp);
   EXPECT_TRUE(succeeded(maybeNormalizedForallOp));
   OwningOpRef<scf::ForallOp> normalizedForallOp(*maybeNormalizedForallOp);
   checkNormalized(normalizedForallOp.get());
+
+  // Check that the IV user has been updated to use the denormalized variable.
+  Block *body = normalizedForallOp->getBody();
+  auto bitcastOps = body->getOps<arith::BitcastOp>();
+  ASSERT_EQ(std::distance(bitcastOps.begin(), bitcastOps.end()), 1);
+  arith::BitcastOp ivUser = *bitcastOps.begin();
+  ASSERT_NE(ivUser.getIn(), normalizedForallOp->getInductionVar(0));
 }



More information about the Mlir-commits mailing list