[Mlir-commits] [mlir] [mlir] Add `res()` method to `linalg::ContractionOpInterface` (PR #76539)

Jerry Wu llvmlistbot at llvm.org
Wed Jan 3 11:03:28 PST 2024


https://github.com/pzread updated https://github.com/llvm/llvm-project/pull/76539

>From 524cf6b787d0289d8855cf055a73013c37746f44 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 28 Dec 2023 21:50:40 +0000
Subject: [PATCH 1/2] Add res() method

---
 mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9b..033cd25a99c025 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -54,6 +54,14 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
       return $_op.getOperation()->getOperand(1);
     }]>,
     InterfaceMethod<
+    /*desc=*/"Returns the result value.",
+    /*retTy=*/"OpResult",
+    /*methodName=*/"res",
+    /*args=*/(ins),
+    /*methodBody=*/[{
+      return $_op.getOperation()->getResult(1);
+    }]>,
+    InterfaceMethod<
     /*desc=*/[{
       Returns whether the given op has indexing maps that correspond to a
       row-major matmul operation.

>From c2c55892a5964b0be3e826dfec9fb514b3c0df6d Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Wed, 3 Jan 2024 19:02:33 +0000
Subject: [PATCH 2/2] Fix and add tests

---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |  2 +-
 mlir/unittests/Dialect/CMakeLists.txt         |  1 +
 mlir/unittests/Dialect/Linalg/CMakeLists.txt  |  8 ++++
 .../Dialect/Linalg/LinalgInterfacesTest.cpp   | 43 +++++++++++++++++++
 4 files changed, 53 insertions(+), 1 deletion(-)
 create mode 100644 mlir/unittests/Dialect/Linalg/CMakeLists.txt
 create mode 100644 mlir/unittests/Dialect/Linalg/LinalgInterfacesTest.cpp

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 033cd25a99c025..777d7cfd558d2a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -59,7 +59,7 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
     /*methodName=*/"res",
     /*args=*/(ins),
     /*methodBody=*/[{
-      return $_op.getOperation()->getResult(1);
+      return $_op.getOperation()->getResult(0);
     }]>,
     InterfaceMethod<
     /*desc=*/[{
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 2dec4ba3c001e8..76b698d1d1a7b1 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -8,6 +8,7 @@ target_link_libraries(MLIRDialectTests
 
 add_subdirectory(ArmSME)
 add_subdirectory(Index)
+add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)
 add_subdirectory(SCF)
diff --git a/mlir/unittests/Dialect/Linalg/CMakeLists.txt b/mlir/unittests/Dialect/Linalg/CMakeLists.txt
new file mode 100644
index 00000000000000..080caab8d075ec
--- /dev/null
+++ b/mlir/unittests/Dialect/Linalg/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_unittest(MLIRLinalgTests
+  LinalgInterfacesTest.cpp
+)
+target_link_libraries(MLIRLinalgTests
+  PRIVATE
+  MLIRLinalgDialect
+  )
+
diff --git a/mlir/unittests/Dialect/Linalg/LinalgInterfacesTest.cpp b/mlir/unittests/Dialect/Linalg/LinalgInterfacesTest.cpp
new file mode 100644
index 00000000000000..8cc4a5e37c4529
--- /dev/null
+++ b/mlir/unittests/Dialect/Linalg/LinalgInterfacesTest.cpp
@@ -0,0 +1,43 @@
+//===- LinalgInterfacesTest.cpp - LinalgInterfaces unit tests -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+class LinalgInterfacesTest : public ::testing::Test {
+protected:
+  LinalgInterfacesTest() {
+    context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
+  }
+
+  mlir::MLIRContext context;
+};
+
+TEST_F(LinalgInterfacesTest, ContractionOpOperandResultAccessor) {
+  OpBuilder b(&context);
+  SmallVector<int64_t> lhsShape = {1, 2};
+  SmallVector<int64_t> rhsShape = {2, 4};
+  SmallVector<int64_t> resShape = {1, 4};
+  auto lhs = b.create<tensor::EmptyOp>(UnknownLoc::get(&context), lhsShape,
+                                       b.getF32Type());
+  auto rhs = b.create<tensor::EmptyOp>(UnknownLoc::get(&context), rhsShape,
+                                       b.getF32Type());
+  auto out = b.create<tensor::EmptyOp>(UnknownLoc::get(&context), resShape,
+                                       b.getF32Type());
+  Operation *op = b.create<linalg::MatmulOp>(
+      UnknownLoc::get(&context), ValueRange{lhs, rhs}, ValueRange{out});
+  auto contractOp = llvm::cast<linalg::ContractionOpInterface>(op);
+
+  EXPECT_EQ(contractOp.lhs(), op->getOperand(0));
+  EXPECT_EQ(contractOp.rhs(), op->getOperand(1));
+  EXPECT_EQ(contractOp.res(), op->getResult(0));
+}



More information about the Mlir-commits mailing list