[Mlir-commits] [mlir] [mlir][scf]: Removed LoopParams struct and used Range instead (NFC) (PR #95501)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 13 21:09:40 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Aviad Cohen (AviadCo)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/95501.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+3-11) 
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+15-16) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index f719c00213987..da3fe3ceb86be 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,14 +120,6 @@ LogicalResult loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
-/// This structure is to pass and return sets of loop parameters without
-/// confusing the order.
-struct LoopParams {
-  OpFoldResult lowerBound;
-  OpFoldResult upperBound;
-  OpFoldResult step;
-};
-
 /// Transform a loop with a strictly positive step
 ///   for %i = %lb to %ub step %s
 /// into a 0-based loop with step 1
@@ -137,9 +129,9 @@ struct LoopParams {
 /// expected to be either `loop` or another loop perfectly nested under `loop`.
 /// Insert the definition of new bounds immediate before `outer`, which is
 /// expected to be either `loop` or its parent in the loop nest.
-LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
-                                    OpFoldResult lb, OpFoldResult ub,
-                                    OpFoldResult step);
+Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+                               OpFoldResult lb, OpFoldResult ub,
+                               OpFoldResult step);
 
 /// Get back the original induction variable values after loop normalization.
 void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index a031e53fe0ffb..ff5e3a002263d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -464,9 +464,9 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
-LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
-                                          OpFoldResult lb, OpFoldResult ub,
-                                          OpFoldResult step) {
+Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+                                     OpFoldResult lb, OpFoldResult ub,
+                                     OpFoldResult step) {
   // For non-index types, generate `arith` instructions
   // Check if the loop is already known to have a constant zero lower bound or
   // a constant one step.
@@ -478,8 +478,8 @@ LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
   if (auto stepCst = getConstantIntValue(step))
     isStepOne = stepCst.value() == 1;
 
-  Type loopParamsType = getType(lb);
-  assert(loopParamsType == getType(ub) && loopParamsType == getType(step) &&
+  Type rangeType = getType(lb);
+  assert(rangeType == getType(ub) && rangeType == getType(step) &&
          "expected matching types");
 
   // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
@@ -501,8 +501,8 @@ LoopParams mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
         getValueOrCreateConstantIntOp(rewriter, loc, step));
   }
 
-  OpFoldResult newLowerBound = rewriter.getZeroAttr(loopParamsType);
-  OpFoldResult newStep = rewriter.getOneAttr(loopParamsType);
+  OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
+  OpFoldResult newStep = rewriter.getOneAttr(rangeType);
 
   return {newLowerBound, newUpperBound, newStep};
 }
@@ -626,18 +626,17 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
     Value lb = loop.getLowerBound();
     Value ub = loop.getUpperBound();
     Value step = loop.getStep();
-    auto newLoopParams =
+    auto newLoopRange =
         emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
 
     rewriter.modifyOpInPlace(loop, [&]() {
-      loop.setLowerBound(getValueOrCreateConstantIntOp(
-          rewriter, loop.getLoc(), newLoopParams.lowerBound));
-      loop.setUpperBound(getValueOrCreateConstantIntOp(
-          rewriter, loop.getLoc(), newLoopParams.upperBound));
+      loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
+                                                       newLoopRange.offset));
+      loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
+                                                       newLoopRange.size));
       loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
-                                                 newLoopParams.step));
+                                                 newLoopRange.stride));
     });
-
     rewriter.setInsertionPointToStart(innermost.getBody());
     denormalizeInductionVariable(rewriter, loop.getLoc(),
                                  loop.getInductionVar(), lb, step);
@@ -780,9 +779,9 @@ void mlir::collapseParallelLoops(
     Value lb = loops.getLowerBound()[i];
     Value ub = loops.getUpperBound()[i];
     Value step = loops.getStep()[i];
-    auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
+    auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
     normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
-        rewriter, loops.getLoc(), newLoopParams.upperBound));
+        rewriter, loops.getLoc(), newLoopRange.size));
 
     rewriter.setInsertionPointToStart(loops.getBody());
     denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,

``````````

</details>


https://github.com/llvm/llvm-project/pull/95501


More information about the Mlir-commits mailing list