[Mlir-commits] [mlir] [mlir][Interfaces] Clean up `DestinationStyleOpInterface` (PR #67015)
lorenzo chelini
llvmlistbot at llvm.org
Thu Sep 21 08:48:50 PDT 2023
================
@@ -50,241 +50,156 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
Example of an op that is not in destination style: `%r = tensor.pad %t`.
This op is not in destination style because `%r` and `%t` have different
shape.
-
- Each op that wants to implement DestinationStyleOpInterface needs to define
- the getDpsInitsPositionRange() method.
}];
let cppNamespace = "::mlir";
let methods = [
- // This method has to be defined for every DPS op.
InterfaceMethod<
/*desc=*/"Return start and end indices of the init operands range.",
- /*retTy=*/"std::pair<int64_t, int64_t>",
- /*methodName=*/"getDpsInitsPositionRange",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/""
- >,
- //===------------------------------------------------------------------===//
- // Operands handling.
- //===------------------------------------------------------------------===//
- // The operand list is assumed to start with the input operands and end
- // with the init operands. Therefore, all methods to access the inputs
- // and inits can be expressed if the number of init operands is know.
- InterfaceMethod<
- /*desc=*/"Return the number of inits.",
- /*retTy=*/"int64_t",
- /*methodName=*/"getNumDpsInits",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- auto [start, end] = $_op.getDpsInitsPositionRange();
- return end - start;
- }]
- >,
- InterfaceMethod<
- /*desc=*/"Return the init operands.",
- /*retTy=*/"::mlir::OpOperandVector",
- /*methodName=*/"getDpsInitOperands",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- auto [start, end] = $_op.getDpsInitsPositionRange();
-
- ::mlir::OpOperandVector result;
- result.reserve(end - start);
- for (int i = start; i < end; ++i)
- result.push_back(&$_op->getOpOperand(i));
- return result;
- }]
- >,
- InterfaceMethod<
- /*desc=*/"Return the `i`-th init operand.",
- /*retTy=*/"::mlir::OpOperand *",
- /*methodName=*/"getDpsInitOperand",
- /*args=*/(ins "int64_t":$i),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- assert(i >= 0 && i < $_op.getNumDpsInits());
- auto [start, end] = $_op.getDpsInitsPositionRange();
- return &$_op->getOpOperand(start + i);
- }]
- >,
- InterfaceMethod<
- /*desc=*/"Set the `i`-th init operand.",
- /*retTy=*/"void",
- /*methodName=*/"setDpsInitOperand",
- /*args=*/(ins "int64_t":$i, "::mlir::Value":$value),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- assert(i >= 0 && i < $_op.getNumDpsInits());
- auto [start, end] = $_op.getDpsInitsPositionRange();
- $_op->setOperand(start + i, value);
- }]
- >,
- InterfaceMethod<
- /*desc=*/"Return the number of inputs.",
- /*retTy=*/"int64_t",
- /*methodName=*/"getNumDpsInputs",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return $_op.getNumOperands() - $_op.getNumDpsInits();
- }]
+ /*retTy=*/"::mlir::MutableOperandRange",
+ /*methodName=*/"getDpsInitsMutable",
+ /*args=*/(ins)
>,
- InterfaceMethod<
- /*desc=*/"Return the input operands.",
- /*retTy=*/"::mlir::OpOperandVector",
- /*methodName=*/"getDpsInputOperands",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- auto [start, end] = $_op.getDpsInitsPositionRange();
- int64_t numInits = end - start;
- int64_t numOperands = $_op.getNumOperands();
+ ];
- ::mlir::OpOperandVector result;
- result.reserve(numOperands - numInits);
- for (int i = 0; i < start; ++i)
+ let extraSharedClassDeclaration = [{
+ ::mlir::OperandRange getDpsInits() {
+ return $_op.getDpsInitsMutable();
+ }
+
+ /// Return the number of DPS inits.
+ int64_t getNumDpsInits() { return $_op.getDpsInits().size(); }
+
+ /// Return the `i`-th DPS init.
+ ::mlir::OpOperand *getDpsInitOperand(int64_t i) {
+ return &$_op.getDpsInitsMutable()[i];
+ }
+
+ /// Set the `i`-th DPS init.
+ void setDpsInitOperand(int64_t i, Value value) {
+ assert(i >= 0 && i < $_op.getNumDpsInits() && "invalid index");
+ $_op->setOperand($_op.getDpsInits().getBeginOperandIndex() + i, value);
+ }
+
+ /// Return the number of DPS inits.
+ int64_t getNumDpsInputs() {
+ return $_op->getNumOperands() - $_op.getNumDpsInits();
+ }
+
+ /// Return the DPS input operands.
+ ::llvm::SmallVector<::mlir::OpOperand *> getDpsInputOperands() {
+ ::llvm::SmallVector<::mlir::OpOperand *> result;
+ int64_t numOperands = $_op->getNumOperands();
+ ::mlir::OperandRange range = $_op.getDpsInits();
+ if (range.empty()) {
+ result.reserve(numOperands);
+ for (int64_t i = 0; i < numOperands; ++i)
result.push_back(&$_op->getOpOperand(i));
- for (int i = end; i < numOperands; ++i)
- result.push_back(&$_op->getOpOperand(end + i));
-
return result;
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{ Return the `i`-th input operand. }],
- /*retTy=*/"::mlir::OpOperand *",
- /*methodName=*/"getDpsInputOperand",
- /*args=*/(ins "int64_t":$i),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- assert(i >= 0 && i < getNumDpsInputs());
- auto [start, end] = $_op.getDpsInitsPositionRange();
- return &$_op->getOpOperand(i < start ? i : i + end - start) ;
- }]
- >,
- //===------------------------------------------------------------------===//
- // Input and DpsInit arguments handling.
- //===------------------------------------------------------------------===//
- InterfaceMethod<
- /*desc=*/"Return true if `opOperand` is an input.",
- /*retTy=*/"bool",
- /*methodName=*/"isDpsInput",
- /*args=*/(ins "::mlir::OpOperand *":$opOperand),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- auto [start, end] = $_op.getDpsInitsPositionRange();
- auto operandNumber = opOperand->getOperandNumber();
- return operandNumber < start || operandNumber >= end;
- }]
- >,
- InterfaceMethod<
- /*desc=*/"Return true if `opOperand` is an init.",
- /*retTy=*/"bool",
- /*methodName=*/"isDpsInit",
- /*args=*/(ins "::mlir::OpOperand *":$opOperand),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- auto [start, end] = $_op.getDpsInitsPositionRange();
- auto operandNumber = opOperand->getOperandNumber();
- return operandNumber >= start && operandNumber < end;
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return true if the `opOperand` is a scalar value. A scalar is defined
- as neither a memref nor a tensor value.
- }],
- /*retTy=*/"bool",
- /*methodName=*/"isScalar",
- /*args=*/(ins "::mlir::OpOperand *":$opOperand),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- assert(opOperand->getOwner() == $_op.getOperation());
- return !::llvm::isa<MemRefType, TensorType>(opOperand->get().getType());
- }]
- >,
- InterfaceMethod<
- /*desc=*/"Return the OpResult that is tied to the given OpOperand.",
- /*retTy=*/"::mlir::OpResult",
- /*methodName=*/"getTiedOpResult",
- /*args=*/(ins "::mlir::OpOperand *":$opOperand),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- assert(opOperand->getOwner() == $_op.getOperation());
-
- auto [start, end] = $_op.getDpsInitsPositionRange();
- int64_t resultIndex = opOperand->getOperandNumber() - start;
+ }
+ int64_t firstInitPos = range.getBeginOperandIndex();
+ int64_t numInits = range.size();
+ result.reserve(numOperands - numInits);
+ for (int64_t i = 0; i < firstInitPos; ++i)
+ result.push_back(&$_op->getOpOperand(i));
+ for (int64_t i = firstInitPos + numInits; i < numOperands; ++i)
+ result.push_back(&$_op->getOpOperand(i));
+ return result;
+ }
+
+ /// Return the DPS input operands.
+ ::llvm::SmallVector<::mlir::Value> getDpsInputs() {
+ return ::llvm::to_vector(::llvm::map_range($_op.getDpsInputOperands(), [](OpOperand *o) { return o->get(); }));
----------------
chelini wrote:
Nit: maybe wrap this line to 80 columns?
https://github.com/llvm/llvm-project/pull/67015
More information about the Mlir-commits
mailing list