[Mlir-commits] [mlir] [mlir][Interfaces] `LoopLikeOpInterface`: Add helpers to query tied inits/iter_args (PR #70408)

Matthias Springer llvmlistbot at llvm.org
Fri Oct 27 00:15:22 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/70408

>From ed8d939520bc4fa23dcd5727a49cbef4048d9bd5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 27 Oct 2023 16:14:06 +0900
Subject: [PATCH] [mlir][Interfaces] `LoopLikeOpInterface`: Add helpers to
 query tied inits/iter_args

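The `LoopLikeOpInterface` already has interface methods to query inits and iter_args. This commit adds helper functions to query tied init/iter_arg pairs (and the yielded value tied to an iter_arg), and removes the corresponding `scf::ForOp` functions `getRegionIterArgForOpOperand` and `getOpOperandForRegionIterArg`. Unlike the removed functions, which asserted on invalid input, the new helpers return a null value when the given operand or block argument is not part of a tied pair.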
---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    | 27 ++--------------
 .../mlir/Interfaces/LoopLikeInterface.td      | 31 ++++++++++++++++++-
 .../Linalg/Transforms/HoistPadding.cpp        |  2 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  4 +--
 .../BufferizableOpInterfaceImpl.cpp           | 10 +++---
 .../SCF/Transforms/LoopCanonicalization.cpp   |  2 +-
 .../Dialect/SCF/Transforms/LoopPipelining.cpp |  4 +--
 .../SCF/Transforms/TileUsingInterface.cpp     |  2 +-
 8 files changed, 45 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index fde9176c670bc6b..43beebc1bf54166 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -123,7 +123,8 @@ def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getInitsMutable", "getSingleInductionVar", "getSingleLowerBound",
         "getSingleStep", "getSingleUpperBound", "getYieldedValuesMutable",
-        "promoteIfSingleIteration", "replaceWithAdditionalYields"]>,
+        "getLoopResults", "promoteIfSingleIteration",
+        "replaceWithAdditionalYields"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
@@ -268,28 +269,6 @@ def ForOp : SCF_Op<"for",
     /// Number of operands controlling the loop: lb, ub, step
     unsigned getNumControlOperands() { return 3; }
 
-    /// Get the region iter arg that corresponds to an OpOperand.
-    /// This helper prevents internal op implementation detail leakage to
-    /// clients by hiding the operand / block argument mapping.
-    BlockArgument getRegionIterArgForOpOperand(OpOperand &opOperand) {
-      assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
-             "expected an iter args operand");
-      assert(opOperand.getOwner() == getOperation() &&
-             "opOperand does not belong to this scf::ForOp operation");
-      return getRegionIterArgs()[
-        opOperand.getOperandNumber() - getNumControlOperands()];
-    }
-    /// Get the OpOperand& that corresponds to a region iter arg.
-    /// This helper prevents internal op implementation detail leakage to
-    /// clients by hiding the operand / block argument mapping.
-    OpOperand &getOpOperandForRegionIterArg(BlockArgument bbArg) {
-      assert(bbArg.getArgNumber() >= getNumInductionVars() &&
-             "expected a bbArg that is not an induction variable");
-      assert(bbArg.getOwner()->getParentOp() == getOperation() &&
-             "bbArg does not belong to the scf::ForOp body");
-      return getOperation()->getOpOperand(
-        getNumControlOperands() + bbArg.getArgNumber() - getNumInductionVars());
-    }
     /// Get the OpResult that corresponds to an OpOperand.
     /// Assert that opOperand is an iterArg.
     /// This helper prevents internal op implementation detail leakage to
@@ -963,7 +942,7 @@ def WhileOp : SCF_Op<"while",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getEntrySuccessorOperands"]>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
-        ["getRegionIterArgs", "getYieldedValuesMutable"]>,
+        ["getLoopResults", "getRegionIterArgs", "getYieldedValuesMutable"]>,
      RecursiveMemoryEffects, SingleBlock]> {
   let summary = "a generic 'while' loop";
   let description = [{
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index afb7860491664da..d3d07eec8ebff57 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -238,7 +238,36 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       unsigned firstOperandIndex = initsMutable.begin()->getOperandNumber();
       return OperandRange(
           $_op->operand_begin() + firstOperandIndex,
-          $_op->operand_begin() + firstOperandIndex + initsMutable.size());    }
+          $_op->operand_begin() + firstOperandIndex + initsMutable.size());
+    }
+
+    /// Return the region iter_arg tied to `opOperand`, or null if not an init.
+    BlockArgument getTiedLoopRegionIterArg(OpOperand *opOperand) {
+      auto initsMutable = $_op.getInitsMutable();
+      auto it = llvm::find(initsMutable, *opOperand);
+      if (it == initsMutable.end())
+        return {}; // `opOperand` is not an init operand of this loop.
+      return $_op.getRegionIterArgs()[std::distance(initsMutable.begin(), it)];
+    }
+
+    /// Return the init operand tied to `bbArg`, or nullptr if not an iter_arg.
+    OpOperand *getTiedLoopInit(BlockArgument bbArg) {
+      auto iterArgs = $_op.getRegionIterArgs();
+      auto it = llvm::find(iterArgs, bbArg);
+      if (it == iterArgs.end())
+        return {}; // Not an iter_arg of this loop; callers must check.
+      return &$_op.getInitsMutable()[std::distance(iterArgs.begin(), it)];
+    }
+
+    /// Return the yield operand tied to `bbArg`, or nullptr if not an iter_arg.
+    OpOperand *getTiedLoopYieldedValue(BlockArgument bbArg) {
+      auto iterArgs = $_op.getRegionIterArgs();
+      auto it = llvm::find(iterArgs, bbArg);
+      if (it == iterArgs.end())
+        return {}; // Not an iter_arg of this loop; callers must check.
+      return
+          &$_op.getYieldedValuesMutable()[std::distance(iterArgs.begin(), it)];
+    }
   }];
 
   let verifyWithRegions = 1;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 8fef99bb375095a..19f704f5232ed81 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -558,7 +558,7 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
       break;
     if (forOp != outerLoop && !outerLoop->isAncestor(forOp))
       break;
-    OpOperand &operand = forOp.getOpOperandForRegionIterArg(bbArg);
+    OpOperand &operand = *forOp.getTiedLoopInit(bbArg);
     bvm.map(bbArg, operand.get());
     bbArg = dyn_cast<BlockArgument>(operand.get());
   }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index cb888bc17c571fe..b8b75f3f476a5da 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -994,8 +994,8 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
   // corresponding to the `replacement` value.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(&newBlock, newBlock.begin());
-  BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
-      newForOp->getOpOperand(operand.getOperandNumber()));
+  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
+      &newForOp->getOpOperand(operand.getOperandNumber()));
   Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
                                                  newRegionIterArg);
   newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 455b7d8bcaff0e2..885e00b48ff8434 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -602,7 +602,7 @@ struct ForOpInterface
 
     // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
     // its matching bbArg may.
-    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
+    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -626,7 +626,7 @@ struct ForOpInterface
     // corresponding iter_args and yield values are equivalent.
     auto forOp = cast<scf::ForOp>(op);
     OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
-    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+    auto bbArg = forOp.getTiedLoopRegionIterArg(&forOperand);
     bool equivalentYield = state.areEquivalentBufferizedValues(
         bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
     return equivalentYield ? BufferRelation::Equivalent
@@ -703,15 +703,15 @@ struct ForOpInterface
 
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
-      BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(
-          forOp.getOpOperandForResult(opResult));
+      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(
+          &forOp.getOpOperandForResult(opResult));
       return bufferization::getBufferType(bbArg, options, invocationStack);
     }
 
     // Compute result/argument number.
     BlockArgument bbArg = cast<BlockArgument>(value);
     unsigned resultNum =
-        forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
+        forOp.getResultForOpOperand(*forOp.getTiedLoopInit(bbArg))
             .getResultNumber();
 
     // Compute the bufferized type.
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 43e79d309c66784..eee0791b397ae68 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -98,7 +98,7 @@ struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
     if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
       return failure();
 
-    Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
+    Value initArg = forOp.getTiedLoopInit(blockArg)->get();
     rewriter.updateRootInPlace(
         dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 4ff7965c0858a3b..5537a8b212c51f7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -194,8 +194,8 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
 
 void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
   // Initialize the iteration argument to the loop initiale values.
-  for (BlockArgument &arg : forOp.getRegionIterArgs()) {
-    OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
+  for (auto [arg, operand] :
+       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
     setValueMapping(arg, operand.get(), 0);
   }
   auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2bddb498c21ac33..e649125a09fea6a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -523,7 +523,7 @@ getUntiledProducerFromSliceSource(OpOperand *source,
     scf::ForOp loop = *loopIt;
     if (iterArg.getOwner()->getParentOp() != loop)
       break;
-    source = &loop.getOpOperandForRegionIterArg(iterArg);
+    source = loop.getTiedLoopInit(iterArg);
     loopIt++;
   }
   if (loopIt == loops.rend())



More information about the Mlir-commits mailing list