[Mlir-commits] [mlir] [mlir][Interfaces] Clean up `DestinationStyleOpInterface` (PR #67015)

lorenzo chelini llvmlistbot at llvm.org
Thu Sep 21 08:48:50 PDT 2023


================
@@ -50,241 +50,156 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
     Example of an op that is not in destination style: `%r = tensor.pad %t`.
     This op is not in destination style because `%r` and `%t` have different
     shape.
-
-    Each op that wants to implement DestinationStyleOpInterface needs to define
-    the getDpsInitsPositionRange() method.
   }];
 
   let cppNamespace = "::mlir";
 
   let methods = [
-    // This method has to be defined for every DPS op.
     InterfaceMethod<
       /*desc=*/"Return start and end indices of the init operands range.",
-      /*retTy=*/"std::pair<int64_t, int64_t>",
-      /*methodName=*/"getDpsInitsPositionRange",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/""
-    >,
-    //===------------------------------------------------------------------===//
-    // Operands handling.
-    //===------------------------------------------------------------------===//
-    // The operand list is assumed to start with the input operands and end
-    // with the init operands. Therefore, all methods to access the inputs
-    // and inits can be expressed if the number of init operands is know.
-    InterfaceMethod<
-      /*desc=*/"Return the number of inits.",
-      /*retTy=*/"int64_t",
-      /*methodName=*/"getNumDpsInits",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto [start, end] = $_op.getDpsInitsPositionRange();
-        return end - start;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return the init operands.",
-      /*retTy=*/"::mlir::OpOperandVector",
-      /*methodName=*/"getDpsInitOperands",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto [start, end] = $_op.getDpsInitsPositionRange();
-
-        ::mlir::OpOperandVector result;
-        result.reserve(end - start);
-        for (int i = start; i < end; ++i)
-          result.push_back(&$_op->getOpOperand(i));
-        return result;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return the `i`-th init operand.",
-      /*retTy=*/"::mlir::OpOperand *",
-      /*methodName=*/"getDpsInitOperand",
-      /*args=*/(ins "int64_t":$i),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(i >= 0 && i < $_op.getNumDpsInits());
-        auto [start, end] = $_op.getDpsInitsPositionRange();
-        return &$_op->getOpOperand(start + i);
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Set the `i`-th init operand.",
-      /*retTy=*/"void",
-      /*methodName=*/"setDpsInitOperand",
-      /*args=*/(ins "int64_t":$i, "::mlir::Value":$value),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(i >= 0 && i < $_op.getNumDpsInits());
-        auto [start, end] = $_op.getDpsInitsPositionRange();
-        $_op->setOperand(start + i, value);
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return the number of inputs.",
-      /*retTy=*/"int64_t",
-      /*methodName=*/"getNumDpsInputs",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return $_op.getNumOperands() - $_op.getNumDpsInits();
-      }]
+      /*retTy=*/"::mlir::MutableOperandRange",
+      /*methodName=*/"getDpsInitsMutable",
+      /*args=*/(ins)
     >,
-    InterfaceMethod<
-      /*desc=*/"Return the input operands.",
-      /*retTy=*/"::mlir::OpOperandVector",
-      /*methodName=*/"getDpsInputOperands",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto [start, end] = $_op.getDpsInitsPositionRange();
-        int64_t numInits = end - start;
-        int64_t numOperands = $_op.getNumOperands();
+  ];
 
-        ::mlir::OpOperandVector result;
-        result.reserve(numOperands - numInits);
-        for (int i = 0; i < start; ++i)
+  let extraSharedClassDeclaration = [{
+    ::mlir::OperandRange getDpsInits() {
+      return $_op.getDpsInitsMutable();
+    }
+
+    /// Return the number of DPS inits.
+    int64_t getNumDpsInits() { return $_op.getDpsInits().size(); }
+
+    /// Return the `i`-th DPS init.
+    ::mlir::OpOperand *getDpsInitOperand(int64_t i) {
+      return &$_op.getDpsInitsMutable()[i];
+    }
+
+    /// Set the `i`-th DPS init.
+    void setDpsInitOperand(int64_t i, Value value) {
+      assert(i >= 0 && i < $_op.getNumDpsInits() && "invalid index");
+      $_op->setOperand($_op.getDpsInits().getBeginOperandIndex() + i, value);
+    }
+
+    /// Return the number of DPS inits.
+    int64_t getNumDpsInputs() {
+      return $_op->getNumOperands() - $_op.getNumDpsInits();
+    }
+
+    /// Return the DPS input operands.
+    ::llvm::SmallVector<::mlir::OpOperand *> getDpsInputOperands() {
+      ::llvm::SmallVector<::mlir::OpOperand *> result;
+      int64_t numOperands = $_op->getNumOperands();
+      ::mlir::OperandRange range = $_op.getDpsInits();
+      if (range.empty()) {
+        result.reserve(numOperands);
+        for (int64_t i = 0; i < numOperands; ++i)
           result.push_back(&$_op->getOpOperand(i));
-        for (int i = end; i < numOperands; ++i)
-          result.push_back(&$_op->getOpOperand(end + i));
-
         return result;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{ Return the `i`-th input operand.  }],
-      /*retTy=*/"::mlir::OpOperand *",
-      /*methodName=*/"getDpsInputOperand",
-      /*args=*/(ins "int64_t":$i),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(i >= 0 && i < getNumDpsInputs());
-        auto [start, end] = $_op.getDpsInitsPositionRange();
-        return &$_op->getOpOperand(i < start ? i : i + end - start) ;
-      }]
-    >,
-    //===------------------------------------------------------------------===//
-    // Input and DpsInit arguments handling.
-    //===------------------------------------------------------------------===//
-    InterfaceMethod<
-      /*desc=*/"Return true if `opOperand` is an input.",
-      /*retTy=*/"bool",
-      /*methodName=*/"isDpsInput",
-      /*args=*/(ins "::mlir::OpOperand *":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto [start, end] = $_op.getDpsInitsPositionRange();
-        auto operandNumber = opOperand->getOperandNumber();
-        return operandNumber < start || operandNumber >= end;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return true if `opOperand` is an init.",
-      /*retTy=*/"bool",
-      /*methodName=*/"isDpsInit",
-      /*args=*/(ins "::mlir::OpOperand *":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto [start, end] = $_op.getDpsInitsPositionRange();
-        auto operandNumber = opOperand->getOperandNumber();
-        return operandNumber >= start && operandNumber < end;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return true if the `opOperand` is a scalar value. A scalar is defined
-        as neither a memref nor a tensor value.
-      }],
-      /*retTy=*/"bool",
-      /*methodName=*/"isScalar",
-      /*args=*/(ins "::mlir::OpOperand *":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(opOperand->getOwner() == $_op.getOperation());
-        return !::llvm::isa<MemRefType, TensorType>(opOperand->get().getType());
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return the OpResult that is tied to the given OpOperand.",
-      /*retTy=*/"::mlir::OpResult",
-      /*methodName=*/"getTiedOpResult",
-      /*args=*/(ins "::mlir::OpOperand *":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(opOperand->getOwner() == $_op.getOperation());
-
-        auto [start, end] = $_op.getDpsInitsPositionRange();
-        int64_t resultIndex = opOperand->getOperandNumber() - start;
+      }
+      int64_t firstInitPos = range.getBeginOperandIndex();
+      int64_t numInits = range.size();
+      result.reserve(numOperands - numInits);
+      for (int64_t i = 0; i < firstInitPos; ++i)
+        result.push_back(&$_op->getOpOperand(i));
+      for (int64_t i = firstInitPos + numInits; i < numOperands; ++i)
+        result.push_back(&$_op->getOpOperand(i));
+      return result;
+    }
+
+    /// Return the DPS input operands.
+    ::llvm::SmallVector<::mlir::Value> getDpsInputs() {
+      return ::llvm::to_vector(::llvm::map_range($_op.getDpsInputOperands(), [](OpOperand *o) { return o->get(); }));
----------------
chelini wrote:

nit: this line is longer than 80 columns; maybe wrap it to 80?

https://github.com/llvm/llvm-project/pull/67015


More information about the Mlir-commits mailing list