[Mlir-commits] [mlir] 2ecb1ab - [mlir][scf]: Removed LoopParams struct and used Range instead (NFC) (#95501)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 14 11:56:20 PDT 2024


Author: Aviad Cohen
Date: 2024-06-14T21:56:17+03:00
New Revision: 2ecb1ab6d701b6b4ec451f2c402c80c9fb9dcb14

URL: https://github.com/llvm/llvm-project/commit/2ecb1ab6d701b6b4ec451f2c402c80c9fb9dcb14
DIFF: https://github.com/llvm/llvm-project/commit/2ecb1ab6d701b6b4ec451f2c402c80c9fb9dcb14.diff

LOG: [mlir][scf]: Removed LoopParams struct and used Range instead (NFC) (#95501)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/Utils/Utils.h
    mlir/lib/Dialect/SCF/Utils/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index f719c00213987..da3fe3ceb86be 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,14 +120,6 @@ LogicalResult loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
-/// This structure is to pass and return sets of loop parameters without
-/// confusing the order.
-struct LoopParams {
-  OpFoldResult lowerBound;
-  OpFoldResult upperBound;
-  OpFoldResult step;
-};
-
 /// Transform a loop with a strictly positive step
 ///   for %i = %lb to %ub step %s
 /// into a 0-based loop with step 1
@@ -137,9 +129,9 @@ struct LoopParams {
 /// expected to be either `loop` or another loop perfectly nested under `loop`.
 /// Insert the definition of new bounds immediate before `outer`, which is
 /// expected to be either `loop` or its parent in the loop nest.
-LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
-                                    OpFoldResult lb, OpFoldResult ub,
-                                    OpFoldResult step);
+Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+                               OpFoldResult lb, OpFoldResult ub,
+                               OpFoldResult step);
 
 /// Get back the original induction variable values after loop normalization.
 void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,

diff  --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index a031e53fe0ffb..ff5e3a002263d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -464,9 +464,9 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
-LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
-                                          OpFoldResult lb, OpFoldResult ub,
-                                          OpFoldResult step) {
+Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+                                     OpFoldResult lb, OpFoldResult ub,
+                                     OpFoldResult step) {
   // For non-index types, generate `arith` instructions
   // Check if the loop is already known to have a constant zero lower bound or
   // a constant one step.
@@ -478,8 +478,8 @@ LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
   if (auto stepCst = getConstantIntValue(step))
     isStepOne = stepCst.value() == 1;
 
-  Type loopParamsType = getType(lb);
-  assert(loopParamsType == getType(ub) && loopParamsType == getType(step) &&
+  Type rangeType = getType(lb);
+  assert(rangeType == getType(ub) && rangeType == getType(step) &&
          "expected matching types");
 
   // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
@@ -501,8 +501,8 @@ LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
         getValueOrCreateConstantIntOp(rewriter, loc, step));
   }
 
-  OpFoldResult newLowerBound = rewriter.getZeroAttr(loopParamsType);
-  OpFoldResult newStep = rewriter.getOneAttr(loopParamsType);
+  OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
+  OpFoldResult newStep = rewriter.getOneAttr(rangeType);
 
   return {newLowerBound, newUpperBound, newStep};
 }
@@ -626,18 +626,17 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
     Value lb = loop.getLowerBound();
     Value ub = loop.getUpperBound();
     Value step = loop.getStep();
-    auto newLoopParams =
+    auto newLoopRange =
         emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
 
     rewriter.modifyOpInPlace(loop, [&]() {
-      loop.setLowerBound(getValueOrCreateConstantIntOp(
-          rewriter, loop.getLoc(), newLoopParams.lowerBound));
-      loop.setUpperBound(getValueOrCreateConstantIntOp(
-          rewriter, loop.getLoc(), newLoopParams.upperBound));
+      loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
+                                                       newLoopRange.offset));
+      loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
+                                                       newLoopRange.size));
       loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
-                                                 newLoopParams.step));
+                                                 newLoopRange.stride));
     });
-
     rewriter.setInsertionPointToStart(innermost.getBody());
     denormalizeInductionVariable(rewriter, loop.getLoc(),
                                  loop.getInductionVar(), lb, step);
@@ -780,9 +779,9 @@ void mlir::collapseParallelLoops(
     Value lb = loops.getLowerBound()[i];
     Value ub = loops.getUpperBound()[i];
     Value step = loops.getStep()[i];
-    auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
+    auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
     normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
-        rewriter, loops.getLoc(), newLoopParams.upperBound));
+        rewriter, loops.getLoc(), newLoopRange.size));
 
     rewriter.setInsertionPointToStart(loops.getBody());
     denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,


        


More information about the Mlir-commits mailing list