[Mlir-commits] [mlir] [NFC] Make AggregateOpInterface part of mlir:: instead of linalg:: (PR #70089)

Abhishek Varma llvmlistbot at llvm.org
Thu Oct 26 03:25:22 PDT 2023


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/70089

>From e953a8dc444982a83f701f8b0cbc8377db252953 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <avarma094 at gmail.com>
Date: Tue, 24 Oct 2023 18:23:55 +0000
Subject: [PATCH] [NFC] Make AggregatedOpInterface part of mlir:: instead of
 linalg::

-- Currently, AggregatedOpInterface is part of the mlir::linalg:: namespace.
   This commit moves it into the generic mlir:: namespace (mlir/Interfaces)
   so that operations outside of Linalg can implement it.

Signed-off-by: Abhishek Varma <abhishek at nod-labs.com>
---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  1 +
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 30 ----------
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |  1 +
 .../mlir/Interfaces/AggregatedOpInterface.h   | 26 +++++++++
 .../mlir/Interfaces/AggregatedOpInterface.td  | 57 +++++++++++++++++++
 mlir/include/mlir/Interfaces/CMakeLists.txt   |  1 +
 .../TransformOps/LinalgTransformOps.cpp       |  1 +
 mlir/lib/Interfaces/AggregatedOpInterface.cpp | 15 +++++
 mlir/lib/Interfaces/CMakeLists.txt            |  5 +-
 9 files changed, 106 insertions(+), 31 deletions(-)
 create mode 100644 mlir/include/mlir/Interfaces/AggregatedOpInterface.h
 create mode 100644 mlir/include/mlir/Interfaces/AggregatedOpInterface.td
 create mode 100644 mlir/lib/Interfaces/AggregatedOpInterface.cpp

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index f6ba6586a81a244..d10cab16d80ac06 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/AggregatedOpInterface.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 44e82f452b3cef1..99b2520b5750da9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -817,34 +817,4 @@ def LinalgStructuredInterface
   let verifyWithRegions = 1;
 }
 
-def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
-  let description = [{
-    Interface for decomposing aggregated operations into a sequence of simpler
-    ops.
-  }];
-  let cppNamespace = "::mlir::linalg";
-  let methods = [
-      InterfaceMethod<
-        /*desc=*/[{
-          Method to decompose the operation into simpler operations.
-
-          On success, this method returns one `Value` per result in the
-          original operation.
-          The order of the returned values must match the order of the
-          original values.
-          In other words, the returned vector can be used directly with
-          `RewriterBase::replaceOp(this, returnedValues)`.
-        }],
-        /*retType=*/"FailureOr<SmallVector<Value>>",
-        /*methodName=*/"decomposeOperation",
-        /*args=*/(ins
-            "OpBuilder &":$b),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
-          return {};
-        }]
-      >
-  ];
-}
-
 #endif // LINALG_IR_LINALGINTERFACES
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index da12e7c83b22b89..03372753757e280 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
 include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
+include "mlir/Interfaces/AggregatedOpInterface.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
diff --git a/mlir/include/mlir/Interfaces/AggregatedOpInterface.h b/mlir/include/mlir/Interfaces/AggregatedOpInterface.h
new file mode 100644
index 000000000000000..4dc159db3c38506
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/AggregatedOpInterface.h
@@ -0,0 +1,26 @@
+//===- AggregatedOpInterface.h ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains an interface for decomposing operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_AGGREGATEDOPINTERFACE_H_
+#define MLIR_INTERFACES_AGGREGATEDOPINTERFACE_H_
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/SmallVector.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/AggregatedOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_AGGREGATEDOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/AggregatedOpInterface.td b/mlir/include/mlir/Interfaces/AggregatedOpInterface.td
new file mode 100644
index 000000000000000..646fd477441bc68
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/AggregatedOpInterface.td
@@ -0,0 +1,57 @@
+//===- AggregatedOpInterface.td ----------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_AGGREGATEDOPINTERFACE
+#define MLIR_AGGREGATEDOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
+  let description = [{
+    This interface is useful for operations that can be lowered into a
+    sequence of one or more simpler operations. An operation opts in by
+    implementing the `decomposeOperation` method, which returns the values
+    that replace the uses of the results of the operation being
+    decomposed.
+
+    Example: consider a hypothetical operation `CustomOp_Mul_Add` that takes
+    an input tensor and a constant. It first multiplies the input tensor
+    element-wise by the constant, and then adds the constant element-wise to
+    that intermediate result. Since it is equivalent to a multiplication
+    followed by an addition, `CustomOp_Mul_Add` can be decomposed into those
+    two simpler operations by implementing this interface.
+
+    `linalg::SoftmaxOp` is one existing operation that implements this
+    interface to provide its decomposition.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to decompose the operation into simpler operations.
+
+          On success, this method returns one `Value` per result in the
+          original operation.
+          The order of the returned values must match the order of the
+          original operation's results.
+          In other words, the returned vector can be used directly with
+          `RewriterBase::replaceOp(this, returnedValues)`.
+        }],
+        /*retType=*/"FailureOr<SmallVector<Value>>",
+        /*methodName=*/"decomposeOperation",
+        /*args=*/(ins
+            "OpBuilder &":$b),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return {};
+        }]
+      >
+  ];
+}
+
+#endif // MLIR_AGGREGATEDOPINTERFACE
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 701b46889194da9..855b669079cd471 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_mlir_interface(AggregatedOpInterface)
 add_mlir_interface(CallInterfaces)
 add_mlir_interface(CastInterfaces)
 add_mlir_interface(ControlFlowInterfaces)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8508507871d0c6c..c1871e604422a6e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -36,6 +36,7 @@
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/AggregatedOpInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/TypeID.h"
diff --git a/mlir/lib/Interfaces/AggregatedOpInterface.cpp b/mlir/lib/Interfaces/AggregatedOpInterface.cpp
new file mode 100644
index 000000000000000..2601cb70eb59878
--- /dev/null
+++ b/mlir/lib/Interfaces/AggregatedOpInterface.cpp
@@ -0,0 +1,15 @@
+//===- AggregatedOpInterface.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/AggregatedOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+#include "mlir/Interfaces/AggregatedOpInterface.cpp.inc"
+} // namespace mlir
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 6067a7d8a62926d..05b8b2e3633673c 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -1,4 +1,5 @@
 set(LLVM_OPTIONAL_SOURCES
+  AggregatedOpInterface.cpp
   CallInterfaces.cpp
   CastInterfaces.cpp
   ControlFlowInterfaces.cpp
@@ -37,7 +38,7 @@ function(add_mlir_interface_library name)
     )
 endfunction(add_mlir_interface_library)
 
-
+add_mlir_interface_library(AggregatedOpInterface)
 add_mlir_interface_library(CallInterfaces)
 add_mlir_interface_library(CastInterfaces)
 add_mlir_interface_library(ControlFlowInterfaces)
@@ -93,10 +94,12 @@ add_mlir_library(MLIRValueBoundsOpInterface
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
 
   DEPENDS
+  MLIRAggregatedOpInterface
   MLIRDestinationStyleOpInterface
   MLIRValueBoundsOpInterfaceIncGen
 
   LINK_LIBS PUBLIC
+  MLIRAggregatedOpInterface
   MLIRAnalysis
   MLIRDestinationStyleOpInterface
   MLIRIR



More information about the Mlir-commits mailing list