[Mlir-commits] [mlir] 98a6edd - [mlir][Interfaces] `LoopLikeOpInterface`: Expose tied loop results (#70535)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 31 16:34:18 PDT 2023


Author: Matthias Springer
Date: 2023-11-01T08:34:14+09:00
New Revision: 98a6edd38f960679e65124d52e3c61f4abd1419f

URL: https://github.com/llvm/llvm-project/commit/98a6edd38f960679e65124d52e3c61f4abd1419f
DIFF: https://github.com/llvm/llvm-project/commit/98a6edd38f960679e65124d52e3c61f4abd1419f.diff

LOG: [mlir][Interfaces] `LoopLikeOpInterface`: Expose tied loop results (#70535)

Expose loop results, which correspond to the region iter_arg values that
are returned from the loop when there are no more iterations. Exposing
loop results is optional because some loops (e.g., `scf.while`) do not
have a 1-to-1 mapping between region iter_args and op results.

Also add additional helper functions to query tied
results/iter_args/inits.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/include/mlir/Interfaces/LoopLikeInterface.td
    mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/lib/Interfaces/LoopLikeInterface.cpp
    mlir/test/Dialect/SCF/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 43beebc1bf54166..38937fe28949436 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -269,28 +269,6 @@ def ForOp : SCF_Op<"for",
     /// Number of operands controlling the loop: lb, ub, step
     unsigned getNumControlOperands() { return 3; }
 
-    /// Get the OpResult that corresponds to an OpOperand.
-    /// Assert that opOperand is an iterArg.
-    /// This helper prevents internal op implementation detail leakage to
-    /// clients by hiding the operand / block argument mapping.
-    OpResult getResultForOpOperand(OpOperand &opOperand) {
-      assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
-             "expected an iter args operand");
-      assert(opOperand.getOwner() == getOperation() &&
-             "opOperand does not belong to this scf::ForOp operation");
-      return getOperation()->getResult(
-        opOperand.getOperandNumber() - getNumControlOperands());
-    }
-    /// Get the OpOperand& that corresponds to an OpResultOpOperand.
-    /// This helper prevents internal op implementation detail leakage to
-    /// clients by hiding the operand / block argument mapping.
-    OpOperand &getOpOperandForResult(OpResult opResult) {
-      assert(opResult.getDefiningOp() == getOperation() &&
-             "opResult does not belong to the scf::ForOp operation");
-      return getOperation()->getOpOperand(
-        getNumControlOperands() + opResult.getResultNumber());
-    }
-
     /// Returns the step as an `APInt` if it is constant.
     std::optional<APInt> getConstantStep();
 
@@ -942,7 +920,7 @@ def WhileOp : SCF_Op<"while",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getEntrySuccessorOperands"]>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
-        ["getLoopResults", "getRegionIterArgs", "getYieldedValuesMutable"]>,
+        ["getRegionIterArgs", "getYieldedValuesMutable"]>,
      RecursiveMemoryEffects, SingleBlock]> {
   let summary = "a generic 'while' loop";
   let description = [{
@@ -1156,7 +1134,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
 //===----------------------------------------------------------------------===//
 
 def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
-    ParentOneOf<["ExecuteRegionOp, ForOp", "IfOp", "IndexSwitchOp",
+    ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
                  "ParallelOp", "WhileOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{

diff  --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index d3d07eec8ebff57..75d90b67bd82f36 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -33,6 +33,13 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     If one of the respective interface methods is implemented, so must the other
     two. The interface verifier ensures that the number of types of the region
     iter_args, init values and yielded values match.
+
+    Optionally, "loop results" can be exposed through this interface. These are
+    the values that are returned from the loop op when there are no more
+    iterations. The number and types of the loop results must match with the
+    region iter_args. Note: Loop results are optional because some loops
+    (e.g., `scf.while`) may produce results that do not match 1-to-1 with the
+    region iter_args.
   }];
   let cppNamespace = "::mlir";
 
@@ -166,6 +173,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         return {};
       }]
     >,
+    InterfaceMethod<[{
+        Return the range of results that are returned from this loop and
+        correspond to the "init" operands.
+
+        Note: This interface method is optional. If loop results are not
+        exposed via this interface, "std::nullopt" should be returned.
+        Otherwise, the number and types of results must match with the
+        region iter_args, inits and yielded values that are exposed via this
+        interface. If loop results are exposed but this loop op has no
+        loop-carried variables, an empty result range (and not "std::nullopt")
+        should be returned.
+      }],
+      /*retTy=*/"::std::optional<::mlir::ResultRange>",
+      /*methodName=*/"getLoopResults",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::std::nullopt;
+      }]
+    >,
     InterfaceMethod<[{
         Append the specified additional "init" operands: replace this loop with
         a new loop that has the additional init operands. The loop body of
@@ -242,6 +269,8 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     }
 
     /// Return the region iter_arg that corresponds to the given init operand.
+    /// Return an "empty" block argument if the given operand is not an init
+    /// operand of this loop op.
     BlockArgument getTiedLoopRegionIterArg(OpOperand *opOperand) {
       auto initsMutable = $_op.getInitsMutable();
       auto it = llvm::find(initsMutable, *opOperand);
@@ -250,7 +279,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       return $_op.getRegionIterArgs()[std::distance(initsMutable.begin(), it)];
     }
 
+    /// Return the region iter_arg that corresponds to the given loop result.
+    /// Return an "empty" block argument if the given OpResult is not a loop
+    /// result or if this op does not expose any loop results.
+    BlockArgument getTiedLoopRegionIterArg(OpResult opResult) {
+      auto loopResults = $_op.getLoopResults();
+      if (!loopResults)
+        return {};
+      auto it = llvm::find(*loopResults, opResult);
+      if (it == loopResults->end())
+        return {};
+      return $_op.getRegionIterArgs()[std::distance(loopResults->begin(), it)];
+    }
+
     /// Return the init operand that corresponds to the given region iter_arg.
+    /// Return "nullptr" if the given block argument is not a region iter_arg
+    /// of this loop op.
     OpOperand *getTiedLoopInit(BlockArgument bbArg) {
       auto iterArgs = $_op.getRegionIterArgs();
       auto it = llvm::find(iterArgs, bbArg);
@@ -259,7 +303,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       return &$_op.getInitsMutable()[std::distance(iterArgs.begin(), it)];
     }
 
+    /// Return the init operand that corresponds to the given loop result.
+    /// Return "nullptr" if the given OpResult is not a loop result or if this
+    /// op does not expose any loop results.
+    OpOperand *getTiedLoopInit(OpResult opResult) {
+      auto loopResults = $_op.getLoopResults();
+      if (!loopResults)
+        return nullptr;
+      auto it = llvm::find(*loopResults, opResult);
+      if (it == loopResults->end())
+        return nullptr;
+      return &$_op.getInitsMutable()[std::distance(loopResults->begin(), it)];
+    }
+
     /// Return the yielded value that corresponds to the given region iter_arg.
+    /// Return "nullptr" if the given block argument is not a region iter_arg
+    /// of this loop op.
     OpOperand *getTiedLoopYieldedValue(BlockArgument bbArg) {
       auto iterArgs = $_op.getRegionIterArgs();
       auto it = llvm::find(iterArgs, bbArg);
@@ -268,6 +327,34 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       return
           &$_op.getYieldedValuesMutable()[std::distance(iterArgs.begin(), it)];
     }
+
+    /// Return the loop result that corresponds to the given init operand.
+    /// Return an "empty" OpResult if the given operand is not an init operand
+    /// of this loop op or if this op does not expose any loop results.
+    OpResult getTiedLoopResult(OpOperand *opOperand) {
+      auto loopResults = $_op.getLoopResults();
+      if (!loopResults)
+        return {};
+      auto initsMutable = $_op.getInitsMutable();
+      auto it = llvm::find(initsMutable, *opOperand);
+      if (it == initsMutable.end())
+        return {};
+      return (*loopResults)[std::distance(initsMutable.begin(), it)];
+    }
+
+    /// Return the loop result that corresponds to the given region iter_arg.
+    /// Return an "empty" OpResult if the given block argument is not a region
+    /// iter_arg of this loop op or if this op does not expose any loop results.
+    OpResult getTiedLoopResult(BlockArgument bbArg) {
+      auto loopResults = $_op.getLoopResults();
+      if (!loopResults)
+        return {};
+      auto iterArgs = $_op.getRegionIterArgs();
+      auto it = llvm::find(iterArgs, bbArg);
+      if (it == iterArgs.end())
+        return {};
+      return (*loopResults)[std::distance(iterArgs.begin(), it)];
+    }
   }];
 
   let verifyWithRegions = 1;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 19f704f5232ed81..866f51b0e92bbde 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -810,7 +810,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
 
-  unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
+  unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
   auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
                                     .getDefiningOp<tensor::ExtractSliceOp>();
   if (!yieldingExtractSliceOp)

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b8b75f3f476a5da..bc33fe2a9a01079 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -390,6 +390,8 @@ std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
   return OpFoldResult(getUpperBound());
 }
 
+std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
+
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 885e00b48ff8434..dc3c46bf896a9cf 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -614,7 +614,7 @@ struct ForOpInterface
   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                       const AnalysisState &state) const {
     auto forOp = cast<scf::ForOp>(op);
-    OpResult opResult = forOp.getResultForOpOperand(opOperand);
+    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
     BufferRelation relation = bufferRelation(op, opResult, state);
     return {{opResult, relation,
              /*isDefinite=*/relation == BufferRelation::Equivalent}};
@@ -625,10 +625,9 @@ struct ForOpInterface
     // ForOp results are equivalent to their corresponding init_args if the
     // corresponding iter_args and yield values are equivalent.
     auto forOp = cast<scf::ForOp>(op);
-    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
-    auto bbArg = forOp.getTiedLoopRegionIterArg(&forOperand);
+    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
     bool equivalentYield = state.areEquivalentBufferizedValues(
-        bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
+        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
     return equivalentYield ? BufferRelation::Equivalent
                            : BufferRelation::Unknown;
   }
@@ -703,16 +702,13 @@ struct ForOpInterface
 
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
-      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(
-          &forOp.getOpOperandForResult(opResult));
+      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
       return bufferization::getBufferType(bbArg, options, invocationStack);
     }
 
     // Compute result/argument number.
     BlockArgument bbArg = cast<BlockArgument>(value);
-    unsigned resultNum =
-        forOp.getResultForOpOperand(*forOp.getTiedLoopInit(bbArg))
-            .getResultNumber();
+    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
 
     // Compute the bufferized type.
     auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e649125a09fea6a..df162d29a48eb89 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -609,8 +609,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   if (destinationInitArg &&
       (*destinationInitArg)->getOwner() == outerMostLoop) {
     unsigned iterArgNumber =
-        outerMostLoop.getResultForOpOperand(**destinationInitArg)
-            .getResultNumber();
+        outerMostLoop.getTiedLoopResult(*destinationInitArg).getResultNumber();
     int64_t resultNumber = fusableProducer.getResultNumber();
     if (auto dstOp =
             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {

diff  --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index 15a816f4e448839..be1316b95688bf2 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -58,7 +58,7 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
   // but the LoopLikeOpInterface provides better error messages.
   auto loopLikeOp = cast<LoopLikeOpInterface>(op);
 
-  // Verify number of inits/iter_args/yielded values.
+  // Verify number of inits/iter_args/yielded values/loop results.
   if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
     return op->emitOpError("different number of inits and region iter_args: ")
            << loopLikeOp.getInits().size()
@@ -69,21 +69,43 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
                "different number of region iter_args and yielded values: ")
            << loopLikeOp.getRegionIterArgs().size()
            << " != " << loopLikeOp.getYieldedValues().size();
+  if (loopLikeOp.getLoopResults() && loopLikeOp.getLoopResults()->size() !=
+                                         loopLikeOp.getRegionIterArgs().size())
+    return op->emitOpError(
+               "different number of loop results and region iter_args: ")
+           << loopLikeOp.getLoopResults()->size()
+           << " != " << loopLikeOp.getRegionIterArgs().size();
 
-  // Verify types of inits/iter_args/yielded values.
+  // Verify types of inits/iter_args/yielded values/loop results.
   int64_t i = 0;
   for (const auto it :
        llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
                        loopLikeOp.getYieldedValues())) {
     if (std::get<0>(it).getType() != std::get<1>(it).getType())
-      op->emitOpError(std::to_string(i))
-          << "-th init and " << i << "-th region iter_arg have different type: "
-          << std::get<0>(it).getType() << " != " << std::get<1>(it).getType();
+      return op->emitOpError(std::to_string(i))
+             << "-th init and " << i
+             << "-th region iter_arg have different type: "
+             << std::get<0>(it).getType()
+             << " != " << std::get<1>(it).getType();
     if (std::get<1>(it).getType() != std::get<2>(it).getType())
-      op->emitOpError(std::to_string(i))
-          << "-th region iter_arg and " << i
-          << "-th yielded value have different type: "
-          << std::get<1>(it).getType() << " != " << std::get<2>(it).getType();
+      return op->emitOpError(std::to_string(i))
+             << "-th region iter_arg and " << i
+             << "-th yielded value have different type: "
+             << std::get<1>(it).getType()
+             << " != " << std::get<2>(it).getType();
+    ++i;
+  }
+  i = 0;
+  if (loopLikeOp.getLoopResults()) {
+    for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(),
+                                         *loopLikeOp.getLoopResults())) {
+      if (std::get<0>(it).getType() != std::get<1>(it).getType())
+        return op->emitOpError(std::to_string(i))
+               << "-th region iter_arg and " << i
+               << "-th loop result have different type: "
+               << std::get<0>(it).getType()
+               << " != " << std::get<1>(it).getType();
+    }
     ++i;
   }
 

diff  --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 1b2c3f563195c52..ad07a8b11327deb 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -96,6 +96,19 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {
 
 // -----
 
+func.func @scf_for_incorrect_result_type(%arg0: index, %init: f32) {
+  // expected-error @below{{0-th region iter_arg and 0-th loop result have different type: 'f32' != 'f64'}}
+  "scf.for"(%arg0, %arg0, %arg0, %init) (
+    {
+    ^bb0(%i0 : index, %iter: f32):
+      scf.yield %iter : f32
+    }
+  ) : (index, index, index, f32) -> (f64)
+  return
+}
+
+// -----
+
 func.func @too_many_iter_args(%arg0: index, %init: f32) {
   // expected-error @below{{different number of inits and region iter_args: 1 != 2}}
   %x = "scf.for"(%arg0, %arg0, %arg0, %init) (
@@ -449,7 +462,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
   %s0 = arith.constant 0.0 : f32
   %t0 = arith.constant 1.0 : f32
   // expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
-  // expected-error @below {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
   %result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
                     iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
     %sn = arith.addf %si, %si : f32


        


More information about the Mlir-commits mailing list