[Mlir-commits] [mlir] [mlir][linalg][NFC] Make `LinalgOp` inherit from `DestinationStyleOpInterface` (PR #66995)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 21 02:57:08 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
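Dependent interfaces were added to MLIR a while ago, so these TODOs can now be addressed: `LinalgOp` now declares `DestinationStyleOpInterface` as a base interface, which makes the `cast<DestinationStyleOpInterface>(...)` workarounds and the hand-written forwarding methods unnecessary.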
---
Full diff: https://github.com/llvm/llvm-project/pull/66995.diff
1 File Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+7-103)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 78431b9f66f9014..839861c2369ca1d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -13,6 +13,7 @@
#ifndef LINALG_IR_LINALGINTERFACES
#define LINALG_IR_LINALGINTERFACES
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/IR/OpBase.td"
// The 'LinalgContractionOpInterface' provides access to the
@@ -178,7 +179,8 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
}
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
-def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
+def LinalgStructuredInterface
+ : OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
let cppNamespace = "::mlir::linalg";
let methods = [
//===------------------------------------------------------------------===//
@@ -321,13 +323,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- // MLIR currently does not support dependent interfaces or interface
- // inheritance. By construction all ops with StructuredOpInterface must
- // implement DestinationStyleOpInterface.
- // TODO: reevaluate the need for a cast when a better mechanism exists.
- return getBlock()->getArguments().take_front(
- cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumDpsInputs());
+ return getBlock()->getArguments().take_front($_op.getNumDpsInputs());
}]
>,
InterfaceMethod<
@@ -339,13 +335,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- // MLIR currently does not support dependent interfaces or interface
- // inheritance. By construction all ops with StructuredOpInterface must
- // implement DestinationStyleOpInterface.
- // TODO: reevaluate the need for a cast when a better mechanism exists.
- return getBlock()->getArguments().take_back(
- cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumDpsInits());
+ return getBlock()->getArguments().take_back($_op.getNumDpsInits());
}]
>,
InterfaceMethod<
@@ -418,13 +408,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
assert(result.getOwner() == this->getOperation());
auto indexingMaps =
$_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
- // MLIR currently does not support dependent interfaces or interface
- // inheritance. By construction all ops with StructuredOpInterface must
- // implement DestinationStyleOpInterface.
- // TODO: reevaluate the need for a cast when a better mechanism exists.
- return *(indexingMaps.begin() +
- cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumDpsInputs() +
+ return *(indexingMaps.begin() + $_op.getNumDpsInputs() +
result.getResultNumber());
}]
>,
@@ -439,14 +423,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
- // MLIR currently does not support dependent interfaces or interface
- // inheritance. By construction all ops with StructuredOpInterface must
- // implement DestinationStyleOpInterface.
- // TODO: reevaluate the need for a cast when a better mechanism exists.
int64_t resultIndex =
- opOperand->getOperandNumber() -
- cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumDpsInputs();
+ opOperand->getOperandNumber() - $_op.getNumDpsInputs();
assert(resultIndex >= 0 &&
resultIndex < this->getOperation()->getNumResults());
Operation *yieldOp = getBlock()->getTerminator();
@@ -800,80 +778,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/// Return the index in the indexingMaps vector that corresponds to this `opOperand`
int64_t getIndexingMapIndex(OpOperand *opOperand);
-
- //========================================================================//
- // Forwarding functions to access interface methods from the
- // DestinationStyleOpInterface.
- // MLIR currently does not support dependent interfaces or interface
- // inheritance. By construction all ops with StructuredOpInterface must
- // implement DestinationStyleOpInterface.
- // TODO: reevaluate the need for a cast when a better mechanism exists.
- //========================================================================//
-
- int64_t getNumDpsInputs() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumDpsInputs();
- }
-
- int64_t getNumDpsInits() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getNumDpsInits();
- }
-
- OpOperandVector getDpsInputOperands() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getDpsInputOperands();
- }
-
- OpOperand *getDpsInputOperand(int64_t i) {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getDpsInputOperand(i);
- }
-
- void setDpsInitOperand(int64_t i, Value value) {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .setDpsInitOperand(i, value);
- }
-
- OpOperandVector getDpsInitOperands() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getDpsInitOperands();
- }
-
- OpOperand *getDpsInitOperand(int64_t i) {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getDpsInitOperand(i);
- }
-
- bool isDpsInput(OpOperand *opOperand) {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .isDpsInput(opOperand);
- }
-
- bool isDpsInit(OpOperand *opOperand) {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .isDpsInit(opOperand);
- }
-
- bool isScalar(OpOperand *opOperand) {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .isScalar(opOperand);
- }
-
- OpResult getTiedOpResult(OpOperand *opOperand) {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .getTiedOpResult(opOperand);
- }
-
- bool hasBufferSemantics() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .hasBufferSemantics();
- }
-
- bool hasTensorSemantics() {
- return cast<DestinationStyleOpInterface>(*this->getOperation())
- .hasTensorSemantics();
- }
}];
let verify = [{ return detail::verifyStructuredOpInterface($_op); }];
``````````
</details>
https://github.com/llvm/llvm-project/pull/66995
More information about the Mlir-commits
mailing list