[Mlir-commits] [mlir] [mlir][Interfaces] `LoopLikeOpInterface`: Expose mutable inits/yielded values (PR #69137)

Matthias Springer llvmlistbot at llvm.org
Mon Oct 23 16:52:18 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/69137

>From 8a0ef3002265113c9e9f4de23341e95137a578ae Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 24 Oct 2023 08:34:51 +0900
Subject: [PATCH] [mlir][Interfaces] `LoopLikeOpInterface`: Expose mutable
 inits/yielded values

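Expose a `MutableArrayRef<OpOperand>` instead of `ValueRange`/`OperandRange`. This allows users of this interface to change the yielded values and the init values. The names of the interface methods follow the auto-generated op accessor names (`get...()` returns `OperandRange`, `get...Mutable()` returns `MutableOperandRange`). A new implicit conversion from `MutableOperandRange` to `MutableArrayRef<OpOperand>` lets ops return their ODS-generated `get...Mutable()` accessors directly. `getInits()` and `getYieldedValues()` are now provided by the interface as helpers derived from the mutable variants.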
---
 .../include/flang/Optimizer/Dialect/FIROps.td | 18 +++++---
 flang/lib/Optimizer/Dialect/FIROps.cpp        | 14 ++++---
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  2 +-
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    | 11 ++---
 mlir/include/mlir/IR/ValueRange.h             |  3 ++
 .../mlir/Interfaces/LoopLikeInterface.td      | 42 +++++++++++++++----
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |  4 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 12 ++++--
 mlir/lib/IR/OperationSupport.cpp              |  4 ++
 9 files changed, 75 insertions(+), 35 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 80d1635e50da24a..dd2e90c3b1a1fde 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2063,7 +2063,8 @@ class region_Op<string mnemonic, list<Trait> traits = []> :
 }
 
 def fir_DoLoopOp : region_Op<"do_loop",
-    [DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
+    [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+        ["getYieldedValuesMutable"]>]> {
   let summary = "generalized loop operation";
   let description = [{
     Generalized high-level looping construct. This operation is similar to
@@ -2119,8 +2120,10 @@ def fir_DoLoopOp : region_Op<"do_loop",
     mlir::Operation::operand_range getIterOperands() {
       return getOperands().drop_front(getNumControlOperands());
     }
-    mlir::OperandRange getInits() { return getIterOperands(); }
-    mlir::ValueRange getYieldedValues();
+    llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
+      return
+          getOperation()->getOpOperands().drop_front(getNumControlOperands());
+    }
 
     void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
     void setUpperBound(mlir::Value bound) { (*this)->setOperand(1, bound); }
@@ -2207,7 +2210,8 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
 }
 
 def fir_IterWhileOp : region_Op<"iterate_while",
-    [DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
+    [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+        ["getYieldedValuesMutable"]>]> {
   let summary = "DO loop with early exit condition";
   let description = [{
     This single-entry, single-exit looping construct is useful for lowering
@@ -2272,8 +2276,10 @@ def fir_IterWhileOp : region_Op<"iterate_while",
     mlir::Operation::operand_range getIterOperands() {
       return getOperands().drop_front(getNumControlOperands());
     }
-    mlir::OperandRange getInits() { return getIterOperands(); }
-    mlir::ValueRange getYieldedValues();
+    llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
+      return
+          getOperation()->getOpOperands().drop_front(getNumControlOperands());
+    }
 
     void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
     void setUpperBound(mlir::Value bound) { (*this)->setOperand(1, bound); }
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 38311832f20dd26..9641b46d4725c80 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1933,10 +1933,11 @@ mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) {
   return {};
 }
 
-mlir::ValueRange fir::IterWhileOp::getYieldedValues() {
+llvm::MutableArrayRef<mlir::OpOperand>
+fir::IterWhileOp::getYieldedValuesMutable() {
   auto *term = getRegion().front().getTerminator();
-  return getFinalValue() ? term->getOperands().drop_front()
-                         : term->getOperands();
+  return getFinalValue() ? term->getOpOperands().drop_front()
+                         : term->getOpOperands();
 }
 
 //===----------------------------------------------------------------------===//
@@ -2244,10 +2245,11 @@ mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) {
   return {};
 }
 
-mlir::ValueRange fir::DoLoopOp::getYieldedValues() {
+llvm::MutableArrayRef<mlir::OpOperand>
+fir::DoLoopOp::getYieldedValuesMutable() {
   auto *term = getRegion().front().getTerminator();
-  return getFinalValue() ? term->getOperands().drop_front()
-                         : term->getOperands();
+  return getFinalValue() ? term->getOpOperands().drop_front()
+                         : term->getOpOperands();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 36fdf390a761744..f9578cf37d5d768 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -121,7 +121,7 @@ def AffineForOp : Affine_Op<"for",
      ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
      ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-      "getSingleUpperBound", "getYieldedValues",
+      "getSingleUpperBound", "getYieldedValuesMutable",
       "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 044ca756b31062c..fde9176c670bc6b 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -121,8 +121,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
-       ["getInits", "getSingleInductionVar", "getSingleLowerBound",
-        "getSingleStep", "getSingleUpperBound", "getYieldedValues",
+       ["getInitsMutable", "getSingleInductionVar", "getSingleLowerBound",
+        "getSingleStep", "getSingleUpperBound", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
@@ -962,7 +962,8 @@ def ReduceReturnOp :
 def WhileOp : SCF_Op<"while",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getEntrySuccessorOperands"]>,
-     DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getRegionIterArgs"]>,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface,
+        ["getRegionIterArgs", "getYieldedValuesMutable"]>,
      RecursiveMemoryEffects, SingleBlock]> {
   let summary = "a generic 'while' loop";
   let description = [{
@@ -1095,10 +1096,6 @@ def WhileOp : SCF_Op<"while",
     ConditionOp getConditionOp();
     YieldOp getYieldOp();
 
-    /// Return the values that are yielded from the "after" region (by the
-    /// scf.yield op).
-    ValueRange getYieldedValues();
-
     Block::BlockArgListType getBeforeArguments();
     Block::BlockArgListType getAfterArguments();
     Block *getBeforeBody() { return &getBefore().front(); }
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index ed69e5824f70b51..51262e2d78716ec 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -158,6 +158,9 @@ class MutableOperandRange {
   /// Allow implicit conversion to an OperandRange.
   operator OperandRange() const;
 
+  /// Allow implicit conversion to a MutableArrayRef.
+  operator MutableArrayRef<OpOperand>() const;
+
   /// Returns the owning operation.
   Operation *getOwner() const { return owner; }
 
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 4d2a66dd3143d28..afb7860491664da 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -130,15 +130,15 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       }]
     >,
     InterfaceMethod<[{
-        Return the "init" operands that are used as initialization values for
-        the region "iter_args" of this loop.
+        Return the mutable "init" operands that are used as initialization
+        values for the region "iter_args" of this loop.
       }],
-      /*retTy=*/"::mlir::OperandRange",
-      /*methodName=*/"getInits",
+      /*retTy=*/"::llvm::MutableArrayRef<::mlir::OpOperand>",
+      /*methodName=*/"getInitsMutable",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return ::mlir::OperandRange($_op->operand_end(), $_op->operand_end());
+        return {};
       }]
     >,
     InterfaceMethod<[{
@@ -155,14 +155,15 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       }]
     >,
     InterfaceMethod<[{
-        Return the values that are yielded to the next iteration.
+        Return the mutable operand range of values that are yielded to the next
+        iteration by the loop terminator.
       }],
-      /*retTy=*/"::mlir::ValueRange",
-      /*methodName=*/"getYieldedValues",
+      /*retTy=*/"::llvm::MutableArrayRef<::mlir::OpOperand>",
+      /*methodName=*/"getYieldedValuesMutable",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return ::mlir::ValueRange();
+        return {};
       }]
     >,
     InterfaceMethod<[{
@@ -215,6 +216,29 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
             return SmallVector<Value>(newBBArgs);
           });
     }
+
+    /// Return the values that are yielded to the next iteration.
+    ::mlir::ValueRange getYieldedValues() {
+      auto mutableValues = $_op.getYieldedValuesMutable();
+      if (mutableValues.empty())
+        return {};
+      // Assumes the range is a contiguous slice of a single op's operands.
+      ::mlir::Operation *yieldOp = mutableValues.begin()->getOwner();
+      unsigned firstOperandIndex = mutableValues.begin()->getOperandNumber();
+      return yieldOp->getOperands().slice(firstOperandIndex,
+                                          mutableValues.size());
+    }
+
+    /// Return the "init" operands that are used as initialization values for
+    /// the region "iter_args" of this loop.
+    ::mlir::OperandRange getInits() {
+      auto initsMutable = $_op.getInitsMutable();
+      if (initsMutable.empty())
+        return ::mlir::OperandRange($_op->operand_end(), $_op->operand_end());
+      // Assumes the inits are a contiguous slice of this op's operands.
+      unsigned firstOperandIndex = initsMutable.begin()->getOperandNumber();
+      return $_op->getOperands().slice(firstOperandIndex, initsMutable.size());
+    }
   }];
 
   let verifyWithRegions = 1;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index f2b3171c1ab837b..85d16088c43fb1e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2215,8 +2215,8 @@ unsigned AffineForOp::getNumIterOperands() {
   return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
 }
 
-ValueRange AffineForOp::getYieldedValues() {
-  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperands();
+MutableArrayRef<OpOperand> AffineForOp::getYieldedValuesMutable() {
+  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
 }
 
 void AffineForOp::print(OpAsmPrinter &p) {
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 20a7b283c938d00..cb888bc17c571fe 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -527,7 +527,9 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
 
 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
 
-OperandRange ForOp::getInits() { return getInitArgs(); }
+MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
+  return getInitArgsMutable();
+}
 
 FailureOr<LoopLikeOpInterface>
 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
@@ -1221,8 +1223,8 @@ std::optional<APInt> ForOp::getConstantStep() {
   return {};
 }
 
-ValueRange ForOp::getYieldedValues() {
-  return cast<scf::YieldOp>(getBody()->getTerminator()).getResults();
+MutableArrayRef<OpOperand> ForOp::getYieldedValuesMutable() {
+  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
 }
 
 Speculation::Speculatability ForOp::getSpeculatability() {
@@ -3254,7 +3256,9 @@ YieldOp WhileOp::getYieldOp() {
   return cast<YieldOp>(getAfterBody()->getTerminator());
 }
 
-ValueRange WhileOp::getYieldedValues() { return getYieldOp().getResults(); }
+MutableArrayRef<OpOperand> WhileOp::getYieldedValuesMutable() {
+  return getYieldOp().getResultsMutable();
+}
 
 Block::BlockArgListType WhileOp::getBeforeArguments() {
   return getBeforeBody()->getArguments();
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 6726b49dd3d3103..fc5ccd23b5108d8 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -502,6 +502,10 @@ MutableOperandRange::operator OperandRange() const {
   return owner->getOperands().slice(start, length);
 }
 
+MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
+  return owner->getOpOperands().slice(start, length);
+}
+
 MutableOperandRangeRange
 MutableOperandRange::split(NamedAttribute segmentSizes) const {
   return MutableOperandRangeRange(*this, segmentSizes);



More information about the Mlir-commits mailing list