[Mlir-commits] [mlir] [mlir][SCF] scf.for: Consistent API around `initArgs` (PR #66512)

Matthias Springer llvmlistbot at llvm.org
Mon Sep 18 00:08:19 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/66512

>From 9d1ba8cce84366113a4a72eb5c36075b0b74996d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 18 Sep 2023 09:07:37 +0200
Subject: [PATCH] [mlir][SCF] scf.for: Consistent API around `initArgs`

* Always use the auto-generated `getInitArgs` function. Remove the hand-written `getIterOperands` duplicate.
* Remove `hasIterOperands` and `getNumIterOperands`. The names were inconsistent because the "arg" is called `initArgs` in TableGen. Use `getInitArgs().size()` instead.
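* Fix verification of ops with no results: previously, the verifier returned early when the op had no results, so a mismatch between the number of `initArgs` and results (e.g., init args present but no results) was not diagnosed.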

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    | 11 -----
 .../SCFToControlFlow/SCFToControlFlow.cpp     |  3 +-
 .../Conversion/VectorToGPU/VectorToGPU.cpp    | 14 +++---
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp |  2 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 45 ++++++++-----------
 .../SCF/Transforms/LoopCanonicalization.cpp   | 21 +++++----
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |  6 +--
 .../Vector/Transforms/VectorDistribute.cpp    |  2 +-
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        |  2 +-
 mlir/test/Dialect/SCF/invalid.mlir            | 13 ++++++
 10 files changed, 55 insertions(+), 64 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 232e6b0bf4ed772..6d8aaf64e3263b9 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -250,9 +250,6 @@ def ForOp : SCF_Op<"for",
         "expected an index less than the number of region iter args");
       return getBody()->getArguments().drop_front(getNumInductionVars())[index];
     }
-    Operation::operand_range getIterOperands() {
-      return getOperands().drop_front(getNumControlOperands());
-    }
     MutableArrayRef<OpOperand> getIterOpOperands() {
       return
         getOperation()->getOpOperands().drop_front(getNumControlOperands());
@@ -273,14 +270,6 @@ def ForOp : SCF_Op<"for",
     }
     /// Number of operands controlling the loop: lb, ub, step
     unsigned getNumControlOperands() { return 3; }
-    /// Does the operation hold operands for loop-carried values
-    bool hasIterOperands() {
-      return getOperation()->getNumOperands() > getNumControlOperands();
-    }
-    /// Get Number of loop-carried values
-    unsigned getNumIterOperands() {
-      return getOperation()->getNumOperands() - getNumControlOperands();
-    }
     /// Get the iter arg number for an operand. If it isnt an iter arg
     /// operand return std::nullopt.
     std::optional<unsigned> getIterArgNumberForOpOperand(OpOperand &opOperand) {
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index f5face5929916ae..c9b45fd4a7957b8 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -361,8 +361,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
   // of the loop operation.
   SmallVector<Value, 8> destOperands;
   destOperands.push_back(lowerBound);
-  auto iterOperands = forOp.getIterOperands();
-  destOperands.append(iterOperands.begin(), iterOperands.end());
+  llvm::append_range(destOperands, forOp.getInitArgs());
   rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
 
   // With the body block done, we can fill in the condition block.
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 3089e917d0eed9c..c8871c945cbe759 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1106,14 +1106,14 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
 // updated and needs to be updated separatly for the loop to be correct.
 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
                                                scf::ForOp loop,
-                                               ValueRange newIterOperands) {
+                                               ValueRange newInitArgs) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(loop);
 
   // Create a new loop before the existing one, with the extra operands.
   rewriter.setInsertionPoint(loop);
-  auto operands = llvm::to_vector<4>(loop.getIterOperands());
-  operands.append(newIterOperands.begin(), newIterOperands.end());
+  auto operands = llvm::to_vector<4>(loop.getInitArgs());
+  llvm::append_range(operands, newInitArgs);
   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
       operands);
@@ -1122,7 +1122,7 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
   newLoop.getLoopBody().getBlocks().splice(
       newLoop.getLoopBody().getBlocks().begin(),
       loop.getLoopBody().getBlocks());
-  for (Value operand : newIterOperands)
+  for (Value operand : newInitArgs)
     newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
 
   for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
@@ -1144,14 +1144,14 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
 
   SmallVector<Value> newOperands;
   SmallVector<std::pair<size_t, size_t>> argMapping;
-  for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
+  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
     auto it = valueMapping.find(operand.value());
     if (it == valueMapping.end()) {
       LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
       continue;
     }
     argMapping.push_back(std::make_pair(
-        operand.index(), op.getNumIterOperands() + newOperands.size()));
+        operand.index(), op.getInitArgs().size() + newOperands.size()));
     newOperands.push_back(it->second);
   }
 
@@ -1183,7 +1183,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
       continue;
     // Replace the yield of old value with the for op argument to make it easier
     // to remove the dead code.
-    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
+    yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
     yieldOperands.push_back(it->second);
   }
   rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index cf3fd4ba0a0b5dc..ae0461965c4785c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -225,7 +225,7 @@ static void getProducerOfTensor(Value tensor, OpResult &opResult) {
     }
     if (auto blockArg = dyn_cast<BlockArgument>(tensor)) {
       if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
-        tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
+        tensor = forOp.getInitArgs()[blockArg.getArgNumber()];
         continue;
       }
     }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index ce413b283730407..5565aefbad18db5 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -343,15 +343,11 @@ LogicalResult ForOp::verify() {
   if (matchPattern(getStep(), m_Constant(&step)) && step.getInt() <= 0)
     return emitOpError("constant step operand must be positive");
 
-  auto opNumResults = getNumResults();
-  if (opNumResults == 0)
-    return success();
-  // If ForOp defines values, check that the number and types of
-  // the defined values match ForOp initial iter operands and backedge
-  // basic block arguments.
-  if (getNumIterOperands() != opNumResults)
+  // Check that the number of init args and op results is the same.
+  if (getInitArgs().size() != getNumResults())
     return emitOpError(
         "mismatch in number of loop-carried values and defined values");
+
   return success();
 }
 
@@ -362,19 +358,15 @@ LogicalResult ForOp::verifyRegions() {
     return emitOpError(
         "expected induction variable to be same type as bounds and step");
 
-  auto opNumResults = getNumResults();
-  if (opNumResults == 0)
-    return success();
-
-  if (getNumRegionIterArgs() != opNumResults)
+  if (getNumRegionIterArgs() != getNumResults())
     return emitOpError(
         "mismatch in number of basic block args and defined values");
 
-  auto iterOperands = getIterOperands();
+  auto initArgs = getInitArgs();
   auto iterArgs = getRegionIterArgs();
   auto opResults = getResults();
   unsigned i = 0;
-  for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
+  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
     if (std::get<0>(e).getType() != std::get<2>(e).getType())
       return emitOpError() << "types mismatch between " << i
                            << "th iter operand and defined value";
@@ -419,7 +411,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
   // iter_args.
   SmallVector<Value> bbArgReplacements;
   bbArgReplacements.push_back(getLowerBound());
-  bbArgReplacements.append(getIterOperands().begin(), getIterOperands().end());
+  llvm::append_range(bbArgReplacements, getInitArgs());
 
   // Move the loop body operations to the loop's containing block.
   rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
@@ -456,16 +448,15 @@ void ForOp::print(OpAsmPrinter &p) {
   p << " " << getInductionVar() << " = " << getLowerBound() << " to "
     << getUpperBound() << " step " << getStep();
 
-  printInitializationList(p, getRegionIterArgs(), getIterOperands(),
-                          " iter_args");
-  if (!getIterOperands().empty())
-    p << " -> (" << getIterOperands().getTypes() << ')';
+  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
+  if (!getInitArgs().empty())
+    p << " -> (" << getInitArgs().getTypes() << ')';
   p << ' ';
   if (Type t = getInductionVar().getType(); !t.isIndex())
     p << " : " << t << ' ';
   p.printRegion(getRegion(),
                 /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/hasIterOperands());
+                /*printBlockTerminators=*/!getInitArgs().empty());
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
@@ -751,12 +742,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
     keepMask.reserve(yieldOp.getNumOperands());
     SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
         newResultValues;
-    newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
+    newBlockTransferArgs.reserve(1 + forOp.getInitArgs().size());
     newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
-    newIterArgs.reserve(forOp.getNumIterOperands());
+    newIterArgs.reserve(forOp.getInitArgs().size());
     newYieldValues.reserve(yieldOp.getNumOperands());
     newResultValues.reserve(forOp.getNumResults());
-    for (auto it : llvm::zip(forOp.getIterOperands(),   // iter from outside
+    for (auto it : llvm::zip(forOp.getInitArgs(),       // iter from outside
                              forOp.getRegionIterArgs(), // iter inside region
                              forOp.getResults(),        // op results
                              yieldOp.getOperands()      // iter yield
@@ -876,7 +867,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
     // If the upper bound is the same as the lower bound, the loop does not
     // iterate, just remove it.
     if (op.getLowerBound() == op.getUpperBound()) {
-      rewriter.replaceOp(op, op.getIterOperands());
+      rewriter.replaceOp(op, op.getInitArgs());
       return success();
     }
 
@@ -887,7 +878,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
 
     // If the loop is known to have 0 iterations, remove it.
     if (*diff <= 0) {
-      rewriter.replaceOp(op, op.getIterOperands());
+      rewriter.replaceOp(op, op.getInitArgs());
       return success();
     }
 
@@ -900,9 +891,9 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
     llvm::APInt stepValue = *maybeStepValue;
     if (stepValue.sge(*diff)) {
       SmallVector<Value, 4> blockArgs;
-      blockArgs.reserve(op.getNumIterOperands() + 1);
+      blockArgs.reserve(op.getInitArgs().size() + 1);
       blockArgs.push_back(op.getLowerBound());
-      llvm::append_range(blockArgs, op.getIterOperands());
+      llvm::append_range(blockArgs, op.getInitArgs());
       replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
       return success();
     }
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 1da10ddd6371f42..0cd19fbefa8ef98 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -48,16 +48,15 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) {
       return false;
 
     using tensor::InsertSliceOp;
-    value =
-        llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
-            .template Case<InsertSliceOp>(
-                [&](InsertSliceOp op) { return op.getDest(); })
-            .template Case<ForOp>([&](ForOp forOp) {
-              return isShapePreserving(forOp, opResult.getResultNumber())
-                         ? forOp.getIterOperands()[opResult.getResultNumber()]
-                         : Value();
-            })
-            .Default([&](auto op) { return Value(); });
+    value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+                .template Case<InsertSliceOp>(
+                    [&](InsertSliceOp op) { return op.getDest(); })
+                .template Case<ForOp>([&](ForOp forOp) {
+                  return isShapePreserving(forOp, opResult.getResultNumber())
+                             ? forOp.getInitArgs()[opResult.getResultNumber()]
+                             : Value();
+                })
+                .Default([&](auto op) { return Value(); });
   }
   return false;
 }
@@ -144,7 +143,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
     if (!isShapePreserving(forOp, resultNumber))
       return failure();
     rewriter.updateRootInPlace(dimOp, [&]() {
-      dimOp.getSourceMutable().assign(forOp.getIterOperands()[resultNumber]);
+      dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
     });
     return success();
   }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 9ac751f1915ab14..222a9aa395c4f09 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -46,8 +46,8 @@ mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
   // Create a new loop before the existing one, with the extra operands.
   OpBuilder::InsertionGuard g(builder);
   builder.setInsertionPoint(loop);
-  auto operands = llvm::to_vector(loop.getIterOperands());
-  operands.append(newIterOperands.begin(), newIterOperands.end());
+  auto operands = llvm::to_vector(loop.getInitArgs());
+  llvm::append_range(operands, newIterOperands);
   scf::ForOp newLoop = builder.create<scf::ForOp>(
       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
       operands, [](OpBuilder &, Location, Value, ValueRange) {});
@@ -515,7 +515,7 @@ LogicalResult mlir::loopUnrollByFactor(
       std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
     }
     epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
-                               epilogueForOp.getNumIterOperands(), results);
+                               epilogueForOp.getInitArgs().size(), results);
     (void)epilogueForOp.promoteIfSingleIteration(rewriter);
   }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 05b5ff09321489f..2a50947e976dffb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1491,7 +1491,7 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
       auto forResult = cast<OpResult>(yieldOperand.get());
       newOperands.push_back(
           newWarpOp.getResult(yieldOperand.getOperandNumber()));
-      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
+      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
       resultIdx.push_back(yieldOperand.getOperandNumber());
     }
 
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 91a4db9cb8be8ec..832dd8f2013fa4d 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -493,7 +493,7 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
 
   raw_indented_ostream &os = emitter.ostream();
 
-  OperandRange operands = forOp.getIterOperands();
+  OperandRange operands = forOp.getInitArgs();
   Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
   Operation::result_range results = forOp.getResults();
 
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 0cf587af42637c9..f6044ad10829227 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -83,6 +83,19 @@ func.func @loop_for_single_index_argument(%arg0: index) {
 
 // -----
 
+func.func @not_enough_loop_results(%arg0: index, %init: f32) {
+  // expected-error @below{{mismatch in number of loop-carried values and defined values}}
+  "scf.for"(%arg0, %arg0, %arg0, %init) (
+    {
+    ^bb0(%i0 : index, %iter: f32):
+      scf.yield %iter : f32
+    }
+  ) : (index, index, index, f32) -> ()
+  return
+}
+
+// -----
+
 func.func @loop_if_not_i1(%arg0: index) {
   // expected-error at +1 {{operand #0 must be 1-bit signless integer}}
   "scf.if"(%arg0) ({}, {}) : (index) -> ()



More information about the Mlir-commits mailing list