[Mlir-commits] [mlir] [mlir][scf]: Expose emitNormalizedLoopBounds/denormalizeInductionVariable util functions (NFC) (PR #94429)

Aviad Cohen llvmlistbot at llvm.org
Wed Jun 12 22:21:27 PDT 2024


https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/94429

>From 8191de20e0891c72d705efb5ec13a45b0c4cb0a2 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Wed, 5 Jun 2024 08:07:16 +0300
Subject: [PATCH] [mlir][scf]: Expose
 emitNormalizedLoopBounds/denormalizeInductionVariable util functions

* Also updated normalize/denormalize loop bounds to be folded if
possible.
---
 mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 12 ++-
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   | 26 ++++++
 mlir/lib/Dialect/Arith/Utils/Utils.cpp        | 21 ++++-
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 87 +++++++++----------
 mlir/test/Dialect/Affine/loop-coalescing.mlir | 15 ++--
 mlir/test/Dialect/SCF/transform-ops.mlir      | 15 +++-
 6 files changed, 113 insertions(+), 63 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 5e7945d9b0492..7f4822c3ffa90 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -54,7 +54,13 @@ llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
                                             ArrayRef<int64_t> shape);
 
 /// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
-/// a Value or creates a ConstantIndexOp if it casts to an IntegerAttribute.
+/// a Value or creates a ConstantOp if it casts to an Integer Attribute.
+/// Other attribute types are not supported.
+Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, Type targetType,
+                                    OpFoldResult ofr);
+
+/// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
+/// a Value or creates a ConstantIndexOp if it casts to an Integer Attribute.
 /// Other attribute types are not supported.
 Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                       OpFoldResult ofr);
@@ -88,6 +94,10 @@ Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
 Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
                                   const APFloat &value);
 
+/// Returns the type of the Value or IntegerAttr held by `ofr`.
+/// Other attribute types are not supported.
+Type getIntType(OpFoldResult ofr);
+
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bc09cc7f7fa5e..f719c00213987 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,6 +120,32 @@ LogicalResult loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
+/// This structure is to pass and return sets of loop parameters without
+/// confusing the order.
+struct LoopParams {
+  OpFoldResult lowerBound;
+  OpFoldResult upperBound;
+  OpFoldResult step;
+};
+
+/// Transform a loop with a strictly positive step
+///   for %i = %lb to %ub step %s
+/// into a 0-based loop with step 1
+///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
+///     %i = %ii * %s + %lb
+/// Insert the induction variable remapping in the body of `inner`, which is
+/// expected to be either `loop` or another loop perfectly nested under `loop`.
+/// Insert the definition of new bounds immediately before `outer`, which is
+/// expected to be either `loop` or its parent in the loop nest.
+LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+                                    OpFoldResult lb, OpFoldResult ub,
+                                    OpFoldResult step);
+
+/// Get back the original induction variable values after loop normalization.
+void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
+                                  Value normalizedIv, OpFoldResult origLb,
+                                  OpFoldResult origStep);
+
 /// Tile a nest of standard for loops rooted at `rootForOp` by finding such
 /// parametric tile sizes that the outer loops have a fixed number of iterations
 /// as defined in `sizes`.
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 4ce55a23820cf..61404336a4b7b 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -100,12 +100,20 @@ llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
   return dimsToProject;
 }
 
+Value mlir::getValueOrCreateConstantIntOp(OpBuilder &b, Location loc,
+                                          Type targetType, OpFoldResult ofr) {
+  if (auto value = dyn_cast_if_present<Value>(ofr))
+    return value;
+  auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
+  return b.create<arith::ConstantOp>(
+      loc, b.getIntegerAttr(targetType, attr.getValue().getSExtValue()));
+}
+
 Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                             OpFoldResult ofr) {
-  if (auto value = llvm::dyn_cast_if_present<Value>(ofr))
+  if (auto value = dyn_cast_if_present<Value>(ofr))
     return value;
-  auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(ofr));
-  assert(attr && "expect the op fold result casts to an integer attribute");
+  auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
   return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
 }
 
@@ -294,6 +302,13 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
   return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
 }
 
+Type mlir::getIntType(OpFoldResult ofr) {
+  if (auto value = dyn_cast_if_present<Value>(ofr))
+    return value.getType();
+  auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
+  return attr.getType();
+}
+
 Value ArithBuilder::_and(Value lhs, Value rhs) {
   return b.create<arith::AndIOp>(loc, lhs, rhs);
 }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..f11636fff37c3 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/MathExtras.h"
@@ -29,16 +30,6 @@
 
 using namespace mlir;
 
-namespace {
-// This structure is to pass and return sets of loop parameters without
-// confusing the order.
-struct LoopParams {
-  Value lowerBound;
-  Value upperBound;
-  Value step;
-};
-} // namespace
-
 SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
     RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
     ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
@@ -473,17 +464,9 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
-/// Transform a loop with a strictly positive step
-///   for %i = %lb to %ub step %s
-/// into a 0-based loop with step 1
-///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
-///     %i = %ii * %s + %lb
-/// Insert the induction variable remapping in the body of `inner`, which is
-/// expected to be either `loop` or another loop perfectly nested under `loop`.
-/// Insert the definition of new bounds immediate before `outer`, which is
-/// expected to be either `loop` or its parent in the loop nest.
-static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
-                                           Value lb, Value ub, Value step) {
+LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+                                          OpFoldResult lb, OpFoldResult ub,
+                                          OpFoldResult step) {
   // For non-index types, generate `arith` instructions
   // Check if the loop is already known to have a constant zero lower bound or
   // a constant one step.
@@ -495,32 +478,38 @@ static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
   if (auto stepCst = getConstantIntValue(step))
     isStepOne = stepCst.value() == 1;
 
+  Type loopParamsType = getIntType(lb);
+  assert(loopParamsType == getIntType(ub) &&
+         loopParamsType == getIntType(step) && "expected matching types");
+
   // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
   // assuming the step is strictly positive.  Update the bounds and the step
   // of the loop to go from 0 to the number of iterations, if necessary.
   if (isZeroBased && isStepOne)
     return {lb, ub, step};
 
-  Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb);
-  Value newUpperBound =
-      isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step);
+  OpFoldResult diff = ub;
+  if (!isZeroBased) {
+    diff = rewriter.createOrFold<arith::SubIOp>(
+        loc, getValueOrCreateConstantIntOp(rewriter, loc, loopParamsType, ub),
+        getValueOrCreateConstantIntOp(rewriter, loc, loopParamsType, lb));
+  }
+  OpFoldResult newUpperBound = diff;
+  if (!isStepOne) {
+    newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
+        loc, getValueOrCreateConstantIntOp(rewriter, loc, loopParamsType, diff),
+        getValueOrCreateConstantIntOp(rewriter, loc, loopParamsType, step));
+  }
 
-  Value newLowerBound = isZeroBased
-                            ? lb
-                            : rewriter.create<arith::ConstantOp>(
-                                  loc, rewriter.getZeroAttr(lb.getType()));
-  Value newStep = isStepOne
-                      ? step
-                      : rewriter.create<arith::ConstantOp>(
-                            loc, rewriter.getIntegerAttr(step.getType(), 1));
+  OpFoldResult newLowerBound = rewriter.getZeroAttr(loopParamsType);
+  OpFoldResult newStep = rewriter.getOneAttr(loopParamsType);
 
   return {newLowerBound, newUpperBound, newStep};
 }
 
-/// Get back the original induction variable values after loop normalization
-static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
-                                         Value normalizedIv, Value origLb,
-                                         Value origStep) {
+void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
+                                        Value normalizedIv, OpFoldResult origLb,
+                                        OpFoldResult origStep) {
   Value denormalizedIv;
   SmallPtrSet<Operation *, 2> preserve;
   bool isStepOne = isConstantIntValue(origStep, 1);
@@ -528,12 +517,16 @@ static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
 
   Value scaled = normalizedIv;
   if (!isStepOne) {
-    scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep);
+    Value origStepValue = getValueOrCreateConstantIntOp(
+        rewriter, loc, getIntType(origStep), origStep);
+    scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStepValue);
     preserve.insert(scaled.getDefiningOp());
   }
   denormalizedIv = scaled;
   if (!isZeroBased) {
-    denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLb);
+    Value origLbValue = getValueOrCreateConstantIntOp(
+        rewriter, loc, getIntType(origLb), origLb);
+    denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLbValue);
     preserve.insert(denormalizedIv.getDefiningOp());
   }
 
@@ -638,9 +631,13 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
         emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
 
     rewriter.modifyOpInPlace(loop, [&]() {
-      loop.setLowerBound(newLoopParams.lowerBound);
-      loop.setUpperBound(newLoopParams.upperBound);
-      loop.setStep(newLoopParams.step);
+      Type loopParamsType = lb.getType();
+      loop.setLowerBound(getValueOrCreateConstantIntOp(
+          rewriter, loop.getLoc(), loopParamsType, newLoopParams.lowerBound));
+      loop.setUpperBound(getValueOrCreateConstantIntOp(
+          rewriter, loop.getLoc(), loopParamsType, newLoopParams.upperBound));
+      loop.setStep(getValueOrCreateConstantIntOp(
+          rewriter, loop.getLoc(), loopParamsType, newLoopParams.step));
     });
 
     rewriter.setInsertionPointToStart(innermost.getBody());
@@ -778,8 +775,7 @@ void mlir::collapseParallelLoops(
     llvm::sort(dims);
 
   // Normalize ParallelOp's iteration pattern.
-  SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
-      normalizedUpperBounds;
+  SmallVector<Value, 3> normalizedUpperBounds;
   for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
     OpBuilder::InsertionGuard g2(rewriter);
     rewriter.setInsertionPoint(loops);
@@ -787,9 +783,8 @@ void mlir::collapseParallelLoops(
     Value ub = loops.getUpperBound()[i];
     Value step = loops.getStep()[i];
     auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
-    normalizedLowerBounds.push_back(newLoopParams.lowerBound);
-    normalizedUpperBounds.push_back(newLoopParams.upperBound);
-    normalizedSteps.push_back(newLoopParams.step);
+    normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
+        rewriter, loops.getLoc(), ub.getType(), newLoopParams.upperBound));
 
     rewriter.setInsertionPointToStart(loops.getBody());
     denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir
index ae0adf5a0a02d..0235000aeac53 100644
--- a/mlir/test/Dialect/Affine/loop-coalescing.mlir
+++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir
@@ -74,11 +74,10 @@ func.func @multi_use() {
 
 func.func @unnormalized_loops() {
   // CHECK: %[[orig_step_i:.*]] = arith.constant 2
-  // CHECK: %[[orig_step_j:.*]] = arith.constant 3
+
+  // CHECK: %[[orig_step_j_and_numiter_i:.*]] = arith.constant 3
   // CHECK: %[[orig_lb_i:.*]] = arith.constant 5
   // CHECK: %[[orig_lb_j:.*]] = arith.constant 7
-  // CHECK: %[[orig_ub_i:.*]] = arith.constant 10
-  // CHECK: %[[orig_ub_j:.*]] = arith.constant 17
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
   %c5 = arith.constant 5 : index
@@ -86,20 +85,16 @@ func.func @unnormalized_loops() {
   %c10 = arith.constant 10 : index
   %c17 = arith.constant 17 : index
 
-  // Number of iterations in the outer scf.
-  // CHECK: %[[diff_i:.*]] = arith.subi %[[orig_ub_i]], %[[orig_lb_i]]
-  // CHECK: %[[numiter_i:.*]] = arith.ceildivsi %[[diff_i]], %[[orig_step_i]]
-
   // Normalized lower bound and step for the outer scf.
   // CHECK: %[[lb_i:.*]] = arith.constant 0
   // CHECK: %[[step_i:.*]] = arith.constant 1
 
   // Number of iterations in the inner loop, the pattern is the same as above,
   // only capture the final result.
-  // CHECK: %[[numiter_j:.*]] = arith.ceildivsi {{.*}}, %[[orig_step_j]]
+  // CHECK: %[[numiter_j:.*]] = arith.constant 4
 
   // New bounds of the outer scf.
-  // CHECK: %[[range:.*]] = arith.muli %[[numiter_i]], %[[numiter_j]]
+  // CHECK: %[[range:.*]] = arith.muli %[[orig_step_j_and_numiter_i:.*]], %[[numiter_j]]
   // CHECK: scf.for %[[i:.*]] = %[[lb_i]] to %[[range]] step %[[step_i]]
   scf.for %i = %c5 to %c10 step %c2 {
     // The inner loop has been removed.
@@ -108,7 +103,7 @@ func.func @unnormalized_loops() {
       // The IVs are rewritten.
       // CHECK: %[[normalized_j:.*]] = arith.remsi %[[i]], %[[numiter_j]]
       // CHECK: %[[normalized_i:.*]] = arith.divsi %[[i]], %[[numiter_j]]
-      // CHECK: %[[scaled_j:.*]] = arith.muli %[[normalized_j]], %[[orig_step_j]]
+      // CHECK: %[[scaled_j:.*]] = arith.muli %[[normalized_j]], %[[orig_step_j_and_numiter_i]]
       // CHECK: %[[orig_j:.*]] = arith.addi %[[scaled_j]], %[[orig_lb_j]]
       // CHECK: %[[scaled_i:.*]] = arith.muli %[[normalized_i]], %[[orig_step_i]]
       // CHECK: %[[orig_i:.*]] = arith.addi %[[scaled_i]], %[[orig_lb_i]]
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index a4daa86583c3d..66b6ecfcf3ff2 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -277,13 +277,22 @@ module attributes {transform.with_named_sequence} {
 
 // This test checks for loop coalescing success for non-index loop boundaries and step type
 func.func @coalesce_i32_loops() {
+  // CHECK:           %[[VAL_0:.*]] = arith.constant 0 : i32
+  // CHECK:           %[[VAL_1:.*]] = arith.constant 128 : i32
+  // CHECK:           %[[VAL_2:.*]] = arith.constant 2 : i32
+  // CHECK:           %[[VAL_3:.*]] = arith.constant 64 : i32
   %0 = arith.constant 0 : i32
   %1 = arith.constant 128 : i32
   %2 = arith.constant 2 : i32
   %3 = arith.constant 64 : i32
-  // CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
-  // CHECK-DAG: %[[C1_I32:.*]] = arith.constant 1 : i32
-  // CHECK: scf.for %[[ARG0:.*]] = %[[C0_I32]] to {{.*}} step %[[C1_I32]]  : i32
+  // CHECK:           %[[VAL_4:.*]] = arith.constant 64 : i32
+  // CHECK:           %[[ZERO:.*]] = arith.constant 0 : i32
+  // CHECK:           %[[ONE:.*]] = arith.constant 1 : i32
+  // CHECK:           %[[VAL_7:.*]] = arith.constant 32 : i32
+  // CHECK:           %[[VAL_8:.*]] = arith.constant 0 : i32
+  // CHECK:           %[[VAL_9:.*]] = arith.constant 1 : i32
+  // CHECK:           %[[UB:.*]] = arith.muli %[[VAL_4]], %[[VAL_7]] : i32
+  // CHECK:           scf.for %[[VAL_11:.*]] = %[[ZERO]] to %[[UB]] step %[[ONE]]  : i32 {
   scf.for %i = %0 to %1 step %2 : i32 {
     scf.for %j = %0 to %3 step %2 : i32 {
       arith.addi %i, %j : i32



More information about the Mlir-commits mailing list