[Mlir-commits] [mlir] [mlir][Interfaces] Change `getDpsInitsMutable` to return `MutableArrayRef` (PR #69145)

Matthias Springer llvmlistbot at llvm.org
Sun Oct 15 22:29:23 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/69145

`getDpsInitsMutable` now returns a `MutableArrayRef<OpOperand>`. This allows ops to implement the `DestinationStyleOpInterface` even if they do not have any "inits"; an example of such an op is `vector.transfer_read`. Its current implementation returns a `MutableOperandRange` with start 0 and length 0. This is problematic because the API could be misused to append operands, which would create an invalid op. `MutableArrayRef<OpOperand>` is a better abstraction because it does not allow users to change the number of operands.

>From e82ef29f8f90718b36d940ad5cf178a4e6d594ef Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 16 Oct 2023 14:27:21 +0900
Subject: [PATCH] [mlir][Interfaces] Change `getDpsInitsMutable` to return
 `MutableArrayRef`

`getDpsInitsMutable` now returns a `MutableArrayRef<OpOperand>`. This allows ops to implement the `DestinationStyleOpInterface` even if they do not have any "inits"; an example of such an op is `vector.transfer_read`. Its current implementation returns a `MutableOperandRange` with start 0 and length 0. This is problematic because the API could be misused to append operands, which would create an invalid op. `MutableArrayRef<OpOperand>` is a better abstraction because it does not allow users to change the number of operands.
---
 .../Bufferization/IR/BufferizationOps.td       |  5 ++---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td        |  4 +++-
 .../Dialect/Linalg/IR/LinalgStructuredOps.td   | 14 +++++++++-----
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td     |  4 +++-
 .../mlir/Dialect/Tensor/IR/TensorOps.td        |  6 +++---
 .../mlir/Dialect/Vector/IR/VectorOps.td        |  8 +++++---
 mlir/include/mlir/IR/ValueRange.h              |  3 +++
 .../Interfaces/DestinationStyleOpInterface.td  | 18 ++++++++++++------
 .../Bufferization/IR/BufferizationOps.cpp      |  2 +-
 mlir/lib/IR/OperationSupport.cpp               |  8 ++++++--
 .../Interfaces/DestinationStyleOpInterface.cpp | 17 +++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td          |  6 +++---
 12 files changed, 67 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index c779d1f843d76a0..86c013fd1323680 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -218,7 +218,8 @@ def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
         [AllShapesMatch<["source", "dest"]>,
          AllElementTypesMatch<["source", "dest"]>,
-         BufferizableOpInterface, DestinationStyleOpInterface,
+         BufferizableOpInterface,
+         DeclareOpInterfaceMethods<DestinationStyleOpInterface>,
          DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
          DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
             ["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
@@ -301,8 +302,6 @@ def Bufferization_MaterializeInDestinationOp
       return ::llvm::cast<RankedTensorType>(getResult().getType());
     }
 
-    MutableOperandRange getDpsInitsMutable();
-
     bool isWritable(Value value, const AnalysisState &state);
   }];
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index da12e7c83b22b89..ce2dfbd8d5f813e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -149,7 +149,9 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
     int64_t getOutputOperandRank() {
       return getOutputOperandType().getRank();
     }
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+    MutableArrayRef<OpOperand> getDpsInitsMutable() {
+      return getOutputMutable();
+    }
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 21a5e5cc47aeb5c..74b4067855b86b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -208,7 +208,9 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
       return nullptr;
     }
 
-    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+    MutableArrayRef<OpOperand> getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
   }];
 
   let hasCanonicalizer = 1;
@@ -281,7 +283,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
     }
 
     // Implement functions necessary for DestinationStyleOpInterface.
-    MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+    MutableArrayRef<OpOperand> getDpsInitsMutable() { return getInitMutable(); }
 
     SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
       return getDpsInputOperands();
@@ -377,7 +379,9 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
     getRegionBuilder() {
       return nullptr;
     }
-    MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
+    MutableArrayRef<OpOperand> getDpsInitsMutable() {
+      return getInitsMutable();
+    }
   }];
 
   let hasCustomAssemblyFormat = 1;
@@ -440,7 +444,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
     }
 
     // Implement functions necessary for DestinationStyleOpInterface.
-    MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+    MutableArrayRef<OpOperand> getDpsInitsMutable() { return getInitMutable(); }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
         mlir::ArrayRef<mlir::NamedAttribute>)>
@@ -508,7 +512,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
     }
 
     // Implement functions necessary for DestinationStyleOpInterface.
-    MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+    MutableArrayRef<OpOperand> getDpsInitsMutable() { return getInitMutable(); }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
         mlir::ArrayRef<mlir::NamedAttribute>)>
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index e1a604a88715f0e..f1c2fe969287fce 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -635,7 +635,9 @@ def ForallOp : SCF_Op<"forall", [
     InParallelOp getTerminator();
 
     // Declare the shared_outs as inits/outs to DestinationStyleOpInterface.
-    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+    MutableArrayRef<OpOperand> getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
   }];
 }
 
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 86a250b77dcc8ee..b033a242143f4d3 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -750,7 +750,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
   }];
 
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+    MutableArrayRef<OpOperand> getDpsInitsMutable() { return getDestMutable(); }
   }];
 
   let hasFolder = 1;
@@ -890,7 +890,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
 
-    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+    MutableArrayRef<OpOperand> getDpsInitsMutable() { return getDestMutable(); }
   }];
 
   let hasCanonicalizer = 1;
@@ -1710,7 +1710,7 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     RankedTensorType getDestType() {
       return ::llvm::cast<RankedTensorType>(getDest().getType()); };
 
-    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+    MutableArrayRef<OpOperand> getDpsInitsMutable() { return getDestMutable(); }
 
     /// Interface method for ConditionallySpeculatable.
     Speculation::Speculatability getSpeculatability();
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..0e1d1af33904898 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1389,8 +1389,8 @@ def Vector_TransferReadOp :
     // MaskableOpInterface methods.
     bool supportsPassthru() { return true; }
 
-    MutableOperandRange getDpsInitsMutable() {
-      return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
+    MutableArrayRef<OpOperand> getDpsInitsMutable() {
+      return {};
     }
   }];
 
@@ -1553,7 +1553,9 @@ def Vector_TransferWriteOp :
     ///  ops of other dialects.
     Value getValue() { return getVector(); }
 
-    MutableOperandRange getDpsInitsMutable() { return getSourceMutable(); }
+    MutableArrayRef<OpOperand> getDpsInitsMutable() {
+      return getSourceMutable();
+    }
   }];
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index ed69e5824f70b51..51262e2d78716ec 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -158,6 +158,9 @@ class MutableOperandRange {
   /// Allow implicit conversion to an OperandRange.
   operator OperandRange() const;
 
+  /// Allow implicit conversion to a MutableArrayRef.
+  operator MutableArrayRef<OpOperand>() const;
+
   /// Returns the owning operation.
   Operation *getOwner() const { return owner; }
 
diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
index 4c52d803e114762..673e3d6160c212e 100644
--- a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
+++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
@@ -20,9 +20,9 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
     Init operands must be ranked tensors or ranked memrefs. Input operands can
     have any type. All non-init operands are DPS inputs.
 
-    The init operands of this op are specified by the MutableOperandRange that
-    the `getDpsInitsMutable` interface methods returns. This implies that the
-    init operands must be a consecutive range of operands.
+    The init operands of this op are specified by the OpOperands that
+    the `getDpsInitsMutable` interface method returns. The init operands must
+    be a consecutive range of operands.
 
     If the op has "tensor semantics", then the input operands are either ranked
     tensors or other non-tensor/memref types ("scalars"). The init operands are
@@ -56,8 +56,8 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
 
   let methods = [
     InterfaceMethod<
-      /*desc=*/"Return start and end indices of the init operands range.",
-      /*retTy=*/"::mlir::MutableOperandRange",
+      /*desc=*/"Return the consecutive range of init operands.",
+      /*retTy=*/"::llvm::MutableArrayRef<::mlir::OpOperand>",
       /*methodName=*/"getDpsInitsMutable",
       /*args=*/(ins)
     >,
@@ -65,7 +65,16 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
 
   let extraSharedClassDeclaration = [{
     ::mlir::OperandRange getDpsInits() {
-      return $_op.getDpsInitsMutable();
+      ::llvm::MutableArrayRef<::mlir::OpOperand> initsMutable =
+          $_op.getDpsInitsMutable();
+      if (initsMutable.empty())
+        return ::mlir::OperandRange($_op->operand_end(), $_op->operand_end());
+      // `getDpsInitsMutable` returns a consecutive range of OpOperands, so the
+      // matching OperandRange starts at the first init's operand number.
+      unsigned firstOperandIndex = initsMutable.front().getOperandNumber();
+      auto firstOperandIt = $_op->operand_begin() + firstOperandIndex;
+      return ::mlir::OperandRange(firstOperandIt,
+                                  firstOperandIt + initsMutable.size());
     }
 
     /// Return the number of DPS inits.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 5716dcc9d905016..7109c1d9f31ad5b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -675,7 +675,7 @@ bool MaterializeInDestinationOp::isWritable(Value value,
   return isa<TensorType>(getDest().getType()) ? true : getWritable();
 }
 
-MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
+MutableArrayRef<OpOperand> MaterializeInDestinationOp::getDpsInitsMutable() {
   return getDestMutable();
 }
 
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 6726b49dd3d3103..41214b998bd6985 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -502,6 +502,10 @@ MutableOperandRange::operator OperandRange() const {
   return owner->getOperands().slice(start, length);
 }
 
+MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
+  return owner->getOpOperands().slice(start, length);
+}
+
 MutableOperandRangeRange
 MutableOperandRange::split(NamedAttribute segmentSizes) const {
   return MutableOperandRangeRange(*this, segmentSizes);
@@ -529,11 +533,11 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const {
 }
 
 MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
-  return owner->getOpOperands().slice(start, length).begin();
+  return static_cast<MutableArrayRef<OpOperand>>(*this).begin();
 }
 
 MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
-  return owner->getOpOperands().slice(start, length).end();
+  return static_cast<MutableArrayRef<OpOperand>>(*this).end();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
index 4e5ef66887cadf8..b8a676a9595b781 100644
--- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
+++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
@@ -31,7 +31,23 @@ LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
       cast<DestinationStyleOpInterface>(op);
 
   SmallVector<OpOperand *> outputTensorOperands;
-  for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
+  MutableArrayRef<OpOperand> initsMutable = dstStyleOp.getDpsInitsMutable();
+#ifndef NDEBUG
+  // Operand number the next DPS init is expected to have. Always initialized
+  // so that no control-flow path reads an indeterminate value.
+  unsigned expectedOperandIdx =
+      initsMutable.empty() ? 0 : initsMutable.front().getOperandNumber();
+#endif // NDEBUG
+  for (OpOperand &operand : initsMutable) {
+#ifndef NDEBUG
+    // DPS inits must be consecutive operands. Since `getDpsInitsMutable`
+    // returns a MutableArrayRef (that does not own the underlying data), it is
+    // currently not possible to return non-consecutive operands and this check
+    // just guards against future changes of this interface.
+    assert(operand.getOperandNumber() == expectedOperandIdx &&
+           "DPS inits must be consecutive operands");
+    ++expectedOperandIdx;
+#endif // NDEBUG
     Type type = operand.get().getType();
     if (isa<RankedTensorType>(type)) {
       outputTensorOperands.push_back(&operand);
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index edb63924b3553f2..75785e42c667569 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2350,7 +2350,7 @@ def TestDestinationStyleOp :
   }];
 
   let extraClassDeclaration = [{
-    mlir::MutableOperandRange getDpsInitsMutable() {
+    mlir::MutableArrayRef<mlir::OpOperand> getDpsInitsMutable() {
       return getOutputsMutable();
     }
   }];
@@ -2411,7 +2411,7 @@ def TestLinalgConvOp :
       return "";
     }
 
-    mlir::MutableOperandRange getDpsInitsMutable() {
+    mlir::MutableArrayRef<mlir::OpOperand> getDpsInitsMutable() {
       return getOutputsMutable();
     }
   }];
@@ -2472,7 +2472,7 @@ def TestLinalgFillOp :
       return "";
     }
 
-    mlir::MutableOperandRange getDpsInitsMutable() {
+    mlir::MutableArrayRef<mlir::OpOperand> getDpsInitsMutable() {
       return getOutputsMutable();
     }
   }];



More information about the Mlir-commits mailing list