[Mlir-commits] [mlir] [mlir][SCF] scf.for: Consistent API around `initArgs` API (PR #66512)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 15 06:31:39 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu
            
<details>
<summary>Changes</summary>
* Always use the auto-generated `getInitArgs` function. Remove the hand-written `getIterOperands` duplicate.
* Remove `hasIterOperands` and `getNumIterOperands`. The names were inconsistent because the "arg" is called `initArgs` in TableGen. Use `getInitArgs().size()` instead.
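* Fix verification of ops with no results. `ForOp::verify` and `ForOp::verifyRegions` previously returned success early when the op had zero results. As a result, a mismatch between the number of results and the number of init args (or region iter args) was never reported.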
--
Full diff: https://github.com/llvm/llvm-project/pull/66512.diff

9 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (-11) 
- (modified) mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp (+1-2) 
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+7-7) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (+1-1) 
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+18-27) 
- (modified) mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp (+10-11) 
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+1-1) 
- (modified) mlir/lib/Target/Cpp/TranslateToCpp.cpp (+1-1) 


<pre>
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 232e6b0bf4ed772..6d8aaf64e3263b9 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -250,9 +250,6 @@ def ForOp : SCF_Op&lt;&quot;for&quot;,
         &quot;expected an index less than the number of region iter args&quot;);
       return getBody()-&gt;getArguments().drop_front(getNumInductionVars())[index];
     }
-    Operation::operand_range getIterOperands() {
-      return getOperands().drop_front(getNumControlOperands());
-    }
     MutableArrayRef&lt;OpOperand&gt; getIterOpOperands() {
       return
         getOperation()-&gt;getOpOperands().drop_front(getNumControlOperands());
@@ -273,14 +270,6 @@ def ForOp : SCF_Op&lt;&quot;for&quot;,
     }
     /// Number of operands controlling the loop: lb, ub, step
     unsigned getNumControlOperands() { return 3; }
-    /// Does the operation hold operands for loop-carried values
-    bool hasIterOperands() {
-      return getOperation()-&gt;getNumOperands() &gt; getNumControlOperands();
-    }
-    /// Get Number of loop-carried values
-    unsigned getNumIterOperands() {
-      return getOperation()-&gt;getNumOperands() - getNumControlOperands();
-    }
     /// Get the iter arg number for an operand. If it isnt an iter arg
     /// operand return std::nullopt.
     std::optional&lt;unsigned&gt; getIterArgNumberForOpOperand(OpOperand &amp;opOperand) {
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index f5face5929916ae..c18cc1d835f4833 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -361,8 +361,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
   // of the loop operation.
   SmallVector&lt;Value, 8&gt; destOperands;
   destOperands.push_back(lowerBound);
-  auto iterOperands = forOp.getIterOperands();
-  destOperands.append(iterOperands.begin(), iterOperands.end());
+  destOperands.append(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
   rewriter.create&lt;cf::BranchOp&gt;(loc, conditionBlock, destOperands);
 
   // With the body block done, we can fill in the condition block.
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 8a46357acd7bf1f..d659a1279158638 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1107,14 +1107,14 @@ convertBroadcastOp(RewriterBase &amp;rewriter, vector::BroadcastOp op,
 // updated and needs to be updated separatly for the loop to be correct.
 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &amp;rewriter,
                                                scf::ForOp loop,
-                                               ValueRange newIterOperands) {
+                                               ValueRange newInitArgs) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(loop);
 
   // Create a new loop before the existing one, with the extra operands.
   rewriter.setInsertionPoint(loop);
-  auto operands = llvm::to_vector&lt;4&gt;(loop.getIterOperands());
-  operands.append(newIterOperands.begin(), newIterOperands.end());
+  auto operands = llvm::to_vector&lt;4&gt;(loop.getInitArgs());
+  operands.append(newInitArgs.begin(), newInitArgs.end());
   scf::ForOp newLoop = rewriter.create&lt;scf::ForOp&gt;(
       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
       operands);
@@ -1123,7 +1123,7 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &amp;rewriter,
   newLoop.getLoopBody().getBlocks().splice(
       newLoop.getLoopBody().getBlocks().begin(),
       loop.getLoopBody().getBlocks());
-  for (Value operand : newIterOperands)
+  for (Value operand : newInitArgs)
     newLoop.getBody()-&gt;addArgument(operand.getType(), operand.getLoc());
 
   for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
@@ -1145,14 +1145,14 @@ static LogicalResult convertForOp(RewriterBase &amp;rewriter, scf::ForOp op,
 
   SmallVector&lt;Value&gt; newOperands;
   SmallVector&lt;std::pair&lt;size_t, size_t&gt;&gt; argMapping;
-  for (const auto &amp;operand : llvm::enumerate(op.getIterOperands())) {
+  for (const auto &amp;operand : llvm::enumerate(op.getInitArgs())) {
     auto it = valueMapping.find(operand.value());
     if (it == valueMapping.end()) {
       LLVM_DEBUG(DBGS() &lt;&lt; &quot;no value mapping for: &quot; &lt;&lt; operand.value() &lt;&lt; &quot;\n&quot;);
       continue;
     }
     argMapping.push_back(std::make_pair(
-        operand.index(), op.getNumIterOperands() + newOperands.size()));
+        operand.index(), op.getInitArgs().size() + newOperands.size()));
     newOperands.push_back(it-&gt;second);
   }
 
@@ -1184,7 +1184,7 @@ convertYieldOp(RewriterBase &amp;rewriter, scf::YieldOp op,
       continue;
     // Replace the yield of old value with the for op argument to make it easier
     // to remove the dead code.
-    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
+    yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
     yieldOperands.push_back(it-&gt;second);
   }
   rewriter.create&lt;scf::YieldOp&gt;(op.getLoc(), yieldOperands);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index cf3fd4ba0a0b5dc..ae0461965c4785c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -225,7 +225,7 @@ static void getProducerOfTensor(Value tensor, OpResult &amp;opResult) {
     }
     if (auto blockArg = dyn_cast&lt;BlockArgument&gt;(tensor)) {
       if (auto forOp = blockArg.getDefiningOp&lt;scf::ForOp&gt;()) {
-        tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
+        tensor = forOp.getInitArgs()[blockArg.getArgNumber()];
         continue;
       }
     }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index ce413b283730407..f471e41f6ea54b7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -343,15 +343,11 @@ LogicalResult ForOp::verify() {
   if (matchPattern(getStep(), m_Constant(&amp;step)) &amp;&amp; step.getInt() &lt;= 0)
     return emitOpError(&quot;constant step operand must be positive&quot;);
 
-  auto opNumResults = getNumResults();
-  if (opNumResults == 0)
-    return success();
-  // If ForOp defines values, check that the number and types of
-  // the defined values match ForOp initial iter operands and backedge
-  // basic block arguments.
-  if (getNumIterOperands() != opNumResults)
+  // Check that the number of init args and op results is the same.
+  if (getInitArgs().size() != getNumResults())
     return emitOpError(
         &quot;mismatch in number of loop-carried values and defined values&quot;);
+
   return success();
 }
 
@@ -362,19 +358,15 @@ LogicalResult ForOp::verifyRegions() {
     return emitOpError(
         &quot;expected induction variable to be same type as bounds and step&quot;);
 
-  auto opNumResults = getNumResults();
-  if (opNumResults == 0)
-    return success();
-
-  if (getNumRegionIterArgs() != opNumResults)
+  if (getNumRegionIterArgs() != getNumResults())
     return emitOpError(
         &quot;mismatch in number of basic block args and defined values&quot;);
 
-  auto iterOperands = getIterOperands();
+  auto initArgs = getInitArgs();
   auto iterArgs = getRegionIterArgs();
   auto opResults = getResults();
   unsigned i = 0;
-  for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
+  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
     if (std::get&lt;0&gt;(e).getType() != std::get&lt;2&gt;(e).getType())
       return emitOpError() &lt;&lt; &quot;types mismatch between &quot; &lt;&lt; i
                            &lt;&lt; &quot;th iter operand and defined value&quot;;
@@ -419,7 +411,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &amp;rewriter) {
   // iter_args.
   SmallVector&lt;Value&gt; bbArgReplacements;
   bbArgReplacements.push_back(getLowerBound());
-  bbArgReplacements.append(getIterOperands().begin(), getIterOperands().end());
+  bbArgReplacements.append(getInitArgs().begin(), getInitArgs().end());
 
   // Move the loop body operations to the loop&#x27;s containing block.
   rewriter.inlineBlockBefore(getBody(), getOperation()-&gt;getBlock(),
@@ -456,16 +448,15 @@ void ForOp::print(OpAsmPrinter &amp;p) {
   p &lt;&lt; &quot; &quot; &lt;&lt; getInductionVar() &lt;&lt; &quot; = &quot; &lt;&lt; getLowerBound() &lt;&lt; &quot; to &quot;
     &lt;&lt; getUpperBound() &lt;&lt; &quot; step &quot; &lt;&lt; getStep();
 
-  printInitializationList(p, getRegionIterArgs(), getIterOperands(),
-                          &quot; iter_args&quot;);
-  if (!getIterOperands().empty())
-    p &lt;&lt; &quot; -&gt; (&quot; &lt;&lt; getIterOperands().getTypes() &lt;&lt; &#x27;)&#x27;;
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), &quot; iter_args&quot;);
+  if (!getInitArgs().empty())
+    p &lt;&lt; &quot; -&gt; (&quot; &lt;&lt; getInitArgs().getTypes() &lt;&lt; &#x27;)&#x27;;
   p &lt;&lt; &#x27; &#x27;;
   if (Type t = getInductionVar().getType(); !t.isIndex())
     p &lt;&lt; &quot; : &quot; &lt;&lt; t &lt;&lt; &#x27; &#x27;;
   p.printRegion(getRegion(),
                 /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/hasIterOperands());
+                /*printBlockTerminators=*/!getInitArgs().empty());
   p.printOptionalAttrDict((*this)-&gt;getAttrs());
 }
 
@@ -751,12 +742,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern&lt;scf::ForOp&gt; {
     keepMask.reserve(yieldOp.getNumOperands());
     SmallVector&lt;Value, 4&gt; newBlockTransferArgs, newIterArgs, newYieldValues,
         newResultValues;
-    newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
+    newBlockTransferArgs.reserve(1 + forOp.getInitArgs().size());
     newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
-    newIterArgs.reserve(forOp.getNumIterOperands());
+    newIterArgs.reserve(forOp.getInitArgs().size());
     newYieldValues.reserve(yieldOp.getNumOperands());
     newResultValues.reserve(forOp.getNumResults());
-    for (auto it : llvm::zip(forOp.getIterOperands(),   // iter from outside
+    for (auto it : llvm::zip(forOp.getInitArgs(),       // iter from outside
                              forOp.getRegionIterArgs(), // iter inside region
                              forOp.getResults(),        // op results
                              yieldOp.getOperands()      // iter yield
@@ -876,7 +867,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern&lt;ForOp&gt; {
     // If the upper bound is the same as the lower bound, the loop does not
     // iterate, just remove it.
     if (op.getLowerBound() == op.getUpperBound()) {
-      rewriter.replaceOp(op, op.getIterOperands());
+      rewriter.replaceOp(op, op.getInitArgs());
       return success();
     }
 
@@ -887,7 +878,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern&lt;ForOp&gt; {
 
     // If the loop is known to have 0 iterations, remove it.
     if (*diff &lt;= 0) {
-      rewriter.replaceOp(op, op.getIterOperands());
+      rewriter.replaceOp(op, op.getInitArgs());
       return success();
     }
 
@@ -900,9 +891,9 @@ struct SimplifyTrivialLoops : public OpRewritePattern&lt;ForOp&gt; {
     llvm::APInt stepValue = *maybeStepValue;
     if (stepValue.sge(*diff)) {
       SmallVector&lt;Value, 4&gt; blockArgs;
-      blockArgs.reserve(op.getNumIterOperands() + 1);
+      blockArgs.reserve(op.getInitArgs().size() + 1);
       blockArgs.push_back(op.getLowerBound());
-      llvm::append_range(blockArgs, op.getIterOperands());
+      llvm::append_range(blockArgs, op.getInitArgs());
       replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
       return success();
     }
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 1da10ddd6371f42..0cd19fbefa8ef98 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -48,16 +48,15 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) {
       return false;
 
     using tensor::InsertSliceOp;
-    value =
-        llvm::TypeSwitch&lt;Operation *, Value&gt;(opResult.getOwner())
-            .template Case&lt;InsertSliceOp&gt;(
-                [&amp;](InsertSliceOp op) { return op.getDest(); })
-            .template Case&lt;ForOp&gt;([&amp;](ForOp forOp) {
-              return isShapePreserving(forOp, opResult.getResultNumber())
-                         ? forOp.getIterOperands()[opResult.getResultNumber()]
-                         : Value();
-            })
-            .Default([&amp;](auto op) { return Value(); });
+    value = llvm::TypeSwitch&lt;Operation *, Value&gt;(opResult.getOwner())
+                .template Case&lt;InsertSliceOp&gt;(
+                    [&amp;](InsertSliceOp op) { return op.getDest(); })
+                .template Case&lt;ForOp&gt;([&amp;](ForOp forOp) {
+                  return isShapePreserving(forOp, opResult.getResultNumber())
+                             ? forOp.getInitArgs()[opResult.getResultNumber()]
+                             : Value();
+                })
+                .Default([&amp;](auto op) { return Value(); });
   }
   return false;
 }
@@ -144,7 +143,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern&lt;OpTy&gt; {
     if (!isShapePreserving(forOp, resultNumber))
       return failure();
     rewriter.updateRootInPlace(dimOp, [&amp;]() {
-      dimOp.getSourceMutable().assign(forOp.getIterOperands()[resultNumber]);
+      dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
     });
     return success();
   }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 9ac751f1915ab14..feafbdd9f5c5404 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -46,7 +46,7 @@ mlir::replaceLoopWithNewYields(OpBuilder &amp;builder, scf::ForOp loop,
   // Create a new loop before the existing one, with the extra operands.
   OpBuilder::InsertionGuard g(builder);
   builder.setInsertionPoint(loop);
-  auto operands = llvm::to_vector(loop.getIterOperands());
+  auto operands = llvm::to_vector(loop.getInitArgs());
   operands.append(newIterOperands.begin(), newIterOperands.end());
   scf::ForOp newLoop = builder.create&lt;scf::ForOp&gt;(
       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
@@ -515,7 +515,7 @@ LogicalResult mlir::loopUnrollByFactor(
       std::get&lt;0&gt;(e).replaceAllUsesWith(std::get&lt;1&gt;(e));
     }
     epilogueForOp-&gt;setOperands(epilogueForOp.getNumControlOperands(),
-                               epilogueForOp.getNumIterOperands(), results);
+                               epilogueForOp.getInitArgs().size(), results);
     (void)epilogueForOp.promoteIfSingleIteration(rewriter);
   }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 05b5ff09321489f..2a50947e976dffb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1491,7 +1491,7 @@ struct WarpOpScfForOp : public OpRewritePattern&lt;WarpExecuteOnLane0Op&gt; {
       auto forResult = cast&lt;OpResult&gt;(yieldOperand.get());
       newOperands.push_back(
           newWarpOp.getResult(yieldOperand.getOperandNumber()));
-      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
+      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
       resultIdx.push_back(yieldOperand.getOperandNumber());
     }
 
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 91a4db9cb8be8ec..832dd8f2013fa4d 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -493,7 +493,7 @@ static LogicalResult printOperation(CppEmitter &amp;emitter, scf::ForOp forOp) {
 
   raw_indented_ostream &amp;os = emitter.ostream();
 
-  OperandRange operands = forOp.getIterOperands();
+  OperandRange operands = forOp.getInitArgs();
   Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
   Operation::result_range results = forOp.getResults();
 
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66512


More information about the Mlir-commits mailing list