[Mlir-commits] [mlir] [mlir][IR] Support op interfaces in `HasParent` trait (PR #91471)

Matthias Springer llvmlistbot at llvm.org
Wed May 8 06:43:41 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/91471

This commit adds support for op interfaces to `HasParent`: an op interface can now be specified as a parent.

To produce useful error messages, a new helper function `getInterfaceName` is generated for every op interface. This is similar to `getOperationName`, which is generated for operations.

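This commit resolves a TODO in `TensorOps.td`: `tensor.parallel_insert_slice` now declares `HasParent<"ParallelCombiningOpInterface">` instead of verifying its parent manually in `verify()`.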

From 7b1d6a75b9f048e85b8b752218999eb39bf14d22 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 8 May 2024 15:39:08 +0200
Subject: [PATCH] [mlir][IR] Support op interfaces in `HasParent` trait

This commit adds support for op interfaces to `HasParent`: an op interface can now be specified as a parent.

To produce useful error messages, a new helper function `getInterfaceName` is generated for every op interface. This is similar to `getOperationName`, which is generated for operations.

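This commit resolves a TODO in `TensorOps.td`: `tensor.parallel_insert_slice` now declares `HasParent<"ParallelCombiningOpInterface">` instead of verifying its parent manually in `verify()`.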
---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  3 +--
 mlir/include/mlir/IR/OpBase.td                |  2 +-
 mlir/include/mlir/IR/OpDefinition.h           | 23 ++++++++++++++++++-
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  4 ----
 mlir/test/Dialect/Tensor/invalid.mlir         |  9 ++++++++
 mlir/tools/mlir-tblgen/OpInterfacesGen.cpp    | 10 ++++++--
 6 files changed, 41 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a403e89a39f98..2d9f4c29f7aad 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1463,8 +1463,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [
 def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
        AttrSizedOperandSegments,
        OffsetSizeAndStrideOpInterface,
-       // TODO: Cannot use an interface here atm, verify this manually for now.
-       // HasParent<"ParallelCombiningOpInterface">
+       HasParent<"ParallelCombiningOpInterface">
   ]> {
   let summary = [{
     Specify the tensor slice update of a single thread of a parent
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 7866ac24c1ccb..b089e72fe8928 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -133,7 +133,7 @@ class SingleBlockImplicitTerminator<string op>
 // Op's regions don't have terminator.
 def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait;
 
-// Op's parent operation is the provided one.
+// Op's parent is the provided op, or an op implementing the provided interface.
 class HasParent<string op>
     : ParamNativeOpTrait<"HasParent", op>, StructuralOpTrait;
 
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 59f094d669099..550f04d9a373b 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1298,7 +1298,9 @@ struct HasParent {
       return op->emitOpError()
              << "expects parent op "
              << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
-             << llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'";
+             << llvm::ArrayRef(
+                    {getOperationOrInterfaceName<ParentOpTypes>()...})
+             << "'";
     }
 
     template <typename ParentOpType =
@@ -1309,6 +1311,25 @@ struct HasParent {
       return llvm::cast<ParentOpType>(parent);
     }
   };
+
+private:
+  /// Detects op interfaces: `T` is one iff `T::getInterfaceName()` is valid.
+  template <typename T, typename = int>
+  struct IsInterface : std::false_type {};
+  template <typename T> // Chosen only if the expression below is well-formed.
+  struct IsInterface<T, decltype((void)T::getInterfaceName(), 0)>
+      : std::true_type {};
+
+  /// Returns `T::getInterfaceName()` if `T` is an op interface, otherwise
+  /// `T::getOperationName()`. Used to name the expected parent in diagnostics.
+  template <typename T>
+  static constexpr StringLiteral getOperationOrInterfaceName() {
+    if constexpr (IsInterface<T>::value) {
+      return T::getInterfaceName();
+    } else {
+      return T::getOperationName();
+    }
+  }
 };
 
 /// A trait for operations that have an attribute specifying operand segments.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7a13f7a7d1355..f45c2e4efdf58 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3455,10 +3455,6 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
 }
 
 LogicalResult ParallelInsertSliceOp::verify() {
-  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
-    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
-           << *(getOperation()->getParentOp());
-
   RankedTensorType expectedType;
   SliceVerificationResult result =
       verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 41b6529f64afa..4205d9c3dcd31 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -698,3 +698,12 @@ func.func @unpack_mismatch_inner_tile_size_and_output_shape(
   %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x8x8xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+func.func @parallel_insert_slice_out_of_context(%a: tensor<5xf32>, %b: tensor<100xf32>) {
+  // expected-error @+1 {{expects parent op 'ParallelCombiningOpInterface'}}
+  tensor.parallel_insert_slice %a into %b[0][5][1]
+      : tensor<5xf32> into tensor<100xf32>
+  return
+}
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 2a7406f42f34b..17babee913f04 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -537,7 +537,7 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
 
   // Emit the derived trait for the interface.
   os << "template <typename " << valueTemplate << ">\n";
-  os << "struct " << interface.getName() << "Trait;\n";
+  os << "struct " << interfaceName << "Trait;\n";
 
   os << "\n} // namespace detail\n";
 
@@ -548,6 +548,11 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
                       interfaceName, interfaceName, interfaceTraitsName,
                       interfaceBaseType);
 
+  // Insert function that returns the name of the interface as a string.
+  os << "  static constexpr ::llvm::StringLiteral getInterfaceName() {\n"
+     << "    return \"" << interfaceName << "\";\n"
+     << "  }\n\n";
+
   // Emit a utility wrapper trait class.
   os << llvm::formatv("  template <typename {1}>\n"
                       "  struct Trait : public detail::{0}Trait<{1}> {{};\n",
@@ -588,7 +593,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
        << "    auto* interface = getInterfaceFor(base);\n"
        << "    if (!interface)\n"
           "      return false;\n"
-          "    " << interfaceName << " odsInterfaceInstance(base, interface);\n"
+          "    "
+       << interfaceName << " odsInterfaceInstance(base, interface);\n"
        << "    " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
        << "\n  }\n";
   }



More information about the Mlir-commits mailing list