[Mlir-commits] [mlir] [mlir][linalg][NFC] Make `LinalgOp` inherit from `DestinationStyleOpInterface` (PR #66995)

Matthias Springer llvmlistbot at llvm.org
Thu Sep 21 02:55:56 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/66995

Dependent interfaces were added a while ago, so the TODOs that worked around their absence (by casting `LinalgOp` to `DestinationStyleOpInterface`) can now be addressed.


>From 54c03962b9c5f370da4b303da71178a635a20174 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 21 Sep 2023 11:54:11 +0200
Subject: [PATCH] [mlir][linalg][NFC] Make `LinalgOp` inherit from
 `DestinationStyleOpInterface`

Dependent interfaces were added a while ago, so the TODOs that worked around their absence (by casting `LinalgOp` to `DestinationStyleOpInterface`) can now be addressed.
---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 110 ++----------------
 1 file changed, 7 insertions(+), 103 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 78431b9f66f9014..839861c2369ca1d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -13,6 +13,7 @@
 #ifndef LINALG_IR_LINALGINTERFACES
 #define LINALG_IR_LINALGINTERFACES
 
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/IR/OpBase.td"
 
 // The 'LinalgContractionOpInterface' provides access to the
@@ -178,7 +179,8 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
 }
 
 // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
-def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
+def LinalgStructuredInterface
+    : OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
   let cppNamespace = "::mlir::linalg";
   let methods = [
     //===------------------------------------------------------------------===//
@@ -321,13 +323,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        // MLIR currently does not support dependent interfaces or interface
-        // inheritance. By construction all ops with StructuredOpInterface must
-        // implement DestinationStyleOpInterface.
-        // TODO: reevaluate the need for a cast when a better mechanism exists.
-        return getBlock()->getArguments().take_front(
-            cast<DestinationStyleOpInterface>(*this->getOperation())
-                .getNumDpsInputs());
+        return getBlock()->getArguments().take_front($_op.getNumDpsInputs());
       }]
     >,
     InterfaceMethod<
@@ -339,13 +335,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        // MLIR currently does not support dependent interfaces or interface
-        // inheritance. By construction all ops with StructuredOpInterface must
-        // implement DestinationStyleOpInterface.
-        // TODO: reevaluate the need for a cast when a better mechanism exists.
-        return getBlock()->getArguments().take_back(
-            cast<DestinationStyleOpInterface>(*this->getOperation())
-                .getNumDpsInits());
+        return getBlock()->getArguments().take_back($_op.getNumDpsInits());
       }]
     >,
     InterfaceMethod<
@@ -418,13 +408,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         assert(result.getOwner() == this->getOperation());
         auto indexingMaps =
           $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
-        // MLIR currently does not support dependent interfaces or interface
-        // inheritance. By construction all ops with StructuredOpInterface must
-        // implement DestinationStyleOpInterface.
-        // TODO: reevaluate the need for a cast when a better mechanism exists.
-        return *(indexingMaps.begin() +
-                 cast<DestinationStyleOpInterface>(*this->getOperation())
-                     .getNumDpsInputs() +
+        return *(indexingMaps.begin() + $_op.getNumDpsInputs() +
                  result.getResultNumber());
       }]
     >,
@@ -439,14 +423,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(opOperand->getOwner() == this->getOperation());
-        // MLIR currently does not support dependent interfaces or interface
-        // inheritance. By construction all ops with StructuredOpInterface must
-        // implement DestinationStyleOpInterface.
-        // TODO: reevaluate the need for a cast when a better mechanism exists.
         int64_t resultIndex =
-            opOperand->getOperandNumber() -
-            cast<DestinationStyleOpInterface>(*this->getOperation())
-                .getNumDpsInputs();
+            opOperand->getOperandNumber() - $_op.getNumDpsInputs();
         assert(resultIndex >= 0 &&
                resultIndex < this->getOperation()->getNumResults());
         Operation *yieldOp = getBlock()->getTerminator();
@@ -800,80 +778,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
 
     /// Return the index in the indexingMaps vector that corresponds to this `opOperand`
     int64_t getIndexingMapIndex(OpOperand *opOperand);
-
-    //========================================================================//
-    // Forwarding functions to access interface methods from the
-    // DestinationStyleOpInterface.
-    // MLIR currently does not support dependent interfaces or interface
-    // inheritance. By construction all ops with StructuredOpInterface must
-    // implement DestinationStyleOpInterface.
-    // TODO: reevaluate the need for a cast when a better mechanism exists.
-    //========================================================================//
-
-    int64_t getNumDpsInputs() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getNumDpsInputs();
-    }
-
-    int64_t getNumDpsInits() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getNumDpsInits();
-    }
-
-    OpOperandVector getDpsInputOperands() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getDpsInputOperands();
-    }
-
-    OpOperand *getDpsInputOperand(int64_t i) {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getDpsInputOperand(i);
-    }
-
-    void setDpsInitOperand(int64_t i, Value value) {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .setDpsInitOperand(i, value);
-    }
-
-    OpOperandVector getDpsInitOperands() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getDpsInitOperands();
-    }
-
-    OpOperand *getDpsInitOperand(int64_t i) {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getDpsInitOperand(i);
-    }
-
-    bool isDpsInput(OpOperand *opOperand) {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .isDpsInput(opOperand);
-    }
-
-    bool isDpsInit(OpOperand *opOperand) {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .isDpsInit(opOperand);
-    }
-
-    bool isScalar(OpOperand *opOperand) {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .isScalar(opOperand);
-    }
-
-    OpResult getTiedOpResult(OpOperand *opOperand) {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getTiedOpResult(opOperand);
-    }
-
-    bool hasBufferSemantics() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .hasBufferSemantics();
-    }
-
-    bool hasTensorSemantics() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .hasTensorSemantics();
-    }
   }];
 
   let verify = [{ return detail::verifyStructuredOpInterface($_op); }];



More information about the Mlir-commits mailing list