[Mlir-commits] [mlir] 53edf12 - [mlir] Add `res()` method to `linalg::ContractionOpInterface` (#76539)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 3 19:34:23 PST 2024
Author: Jerry Wu
Date: 2024-01-03T22:34:19-05:00
New Revision: 53edf12e526704cc251b6a6917319c7cb7a653a0
URL: https://github.com/llvm/llvm-project/commit/53edf12e526704cc251b6a6917319c7cb7a653a0
DIFF: https://github.com/llvm/llvm-project/commit/53edf12e526704cc251b6a6917319c7cb7a653a0.diff
LOG: [mlir] Add `res()` method to `linalg::ContractionOpInterface` (#76539)
In addition to `lhs()` and `rhs()`, which return the left and right operands,
add `res()` to return the result value.
Added:
mlir/unittests/Dialect/Linalg/CMakeLists.txt
mlir/unittests/Dialect/Linalg/LinalgInterfacesTest.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/unittests/Dialect/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9b..777d7cfd558d2a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -54,6 +54,18 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
return $_op.getOperation()->getOperand(1);
}]>,
InterfaceMethod<
+ /*desc=*/[{
+ Returns the result value, or a null `OpResult` if the op has no results
+ (e.g. a contraction operating on buffers rather than tensors).
+ }],
+ /*retTy=*/"OpResult",
+ /*methodName=*/"res",
+ /*args=*/(ins),
+ /*methodBody=*/[{
+ Operation *op = $_op.getOperation();
+ return op->getNumResults() == 0 ? OpResult() : op->getResult(0);
+ }]>,
+ InterfaceMethod<
/*desc=*/[{
Returns whether the given op has indexing maps that correspond to a
row-major matmul operation.
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 2dec4ba3c001e8..76b698d1d1a7b1 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -8,6 +8,7 @@ target_link_libraries(MLIRDialectTests
add_subdirectory(ArmSME)
add_subdirectory(Index)
+add_subdirectory(Linalg) # Builds the Linalg dialect unit tests (MLIRLinalgTests).
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
add_subdirectory(SCF)
diff --git a/mlir/unittests/Dialect/Linalg/CMakeLists.txt b/mlir/unittests/Dialect/Linalg/CMakeLists.txt
new file mode 100644
index 00000000000000..080caab8d075ec
--- /dev/null
+++ b/mlir/unittests/Dialect/Linalg/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Unit tests for the Linalg dialect's op interfaces.
+add_mlir_unittest(MLIRLinalgTests
+ LinalgInterfacesTest.cpp
+)
+target_link_libraries(MLIRLinalgTests
+ PRIVATE
+ MLIRLinalgDialect
+ MLIRTensorDialect
+)
diff --git a/mlir/unittests/Dialect/Linalg/LinalgInterfacesTest.cpp b/mlir/unittests/Dialect/Linalg/LinalgInterfacesTest.cpp
new file mode 100644
index 00000000000000..8cc4a5e37c4529
--- /dev/null
+++ b/mlir/unittests/Dialect/Linalg/LinalgInterfacesTest.cpp
@@ -0,0 +1,59 @@
+//===- LinalgInterfacesTest.cpp - LinalgInterfaces unit tests -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OwningOpRef.h"
+
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+/// Test fixture owning an MLIRContext with the Linalg dialect loaded.
+class LinalgInterfacesTest : public ::testing::Test {
+protected:
+  LinalgInterfacesTest() {
+    context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
+  }
+
+  mlir::MLIRContext context;
+};
+
+/// Checks that the ContractionOpInterface accessors lhs(), rhs() and res()
+/// return the first two operands and the first result of a linalg.matmul on
+/// tensors.
+TEST_F(LinalgInterfacesTest, ContractionOpOperandResultAccessor) {
+  OpBuilder b(&context);
+  Location loc = b.getUnknownLoc();
+  Type f32 = b.getF32Type();
+
+  // Build the ops inside a module so they are owned and destroyed when the
+  // test ends; ops created without an insertion point are left detached and
+  // would leak.
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  b.setInsertionPointToEnd(module->getBody());
+
+  SmallVector<int64_t> lhsShape = {1, 2};
+  SmallVector<int64_t> rhsShape = {2, 4};
+  SmallVector<int64_t> resShape = {1, 4};
+  auto lhs = b.create<tensor::EmptyOp>(loc, lhsShape, f32);
+  auto rhs = b.create<tensor::EmptyOp>(loc, rhsShape, f32);
+  auto out = b.create<tensor::EmptyOp>(loc, resShape, f32);
+  Operation *op = b.create<linalg::MatmulOp>(loc, ValueRange{lhs, rhs},
+                                             ValueRange{out});
+
+  // dyn_cast + ASSERT_TRUE turns a missing interface into a test failure
+  // rather than an assertion failure inside llvm::cast.
+  auto contractOp = llvm::dyn_cast<linalg::ContractionOpInterface>(op);
+  ASSERT_TRUE(contractOp);
+
+  EXPECT_EQ(contractOp.lhs(), op->getOperand(0));
+  EXPECT_EQ(contractOp.rhs(), op->getOperand(1));
+  EXPECT_EQ(contractOp.res(), op->getResult(0));
+}
More information about the Mlir-commits
mailing list