[Mlir-commits] [mlir] 9f186bb - [mlir][ods] Make Type- and AttrInterfaces also `Type`s and `Attr`s

Markus Böck llvmlistbot at llvm.org
Thu Jul 7 03:21:57 PDT 2022


Author: Markus Böck
Date: 2022-07-07T11:54:47+02:00
New Revision: 9f186bb125d697786066f1fdd1d0c0e0479a3a4d

URL: https://github.com/llvm/llvm-project/commit/9f186bb125d697786066f1fdd1d0c0e0479a3a4d
DIFF: https://github.com/llvm/llvm-project/commit/9f186bb125d697786066f1fdd1d0c0e0479a3a4d.diff

LOG: [mlir][ods] Make Type- and AttrInterfaces also `Type`s and `Attr`s

By making TypeInterfaces and AttrInterfaces also Types and Attrs, respectively, it becomes possible to use them anywhere a Type or Attr may go: within the arguments and results of an Op definition, in a RewritePattern, etc.

Prior to this change, users had to separately define a Type or Attr with a predicate checking whether a type or attribute implements a given interface. Such code is now redundant.
Removing such occurrences in upstream dialects will be part of a separate patch.

Implementing this patch required some slight refactoring. In particular, the Interface's `cppClassName` field was renamed to `cppInterfaceName`, because it "clashed" with the TypeConstraint's `cppClassName`: the Interface's `cppClassName` expected just the class name, without any namespaces, while the TypeConstraint's `cppClassName` expected a fully qualified class name.

Differential Revision: https://reviews.llvm.org/D129209

Added: 
    mlir/test/mlir-tblgen/interfaces-as-constraints.td

Modified: 
    mlir/docs/PDLL.md
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/TableGen/Interfaces.cpp
    mlir/lib/Tools/PDLL/Parser/Parser.cpp
    mlir/test/mlir-pdll/Parser/include_td.pdll
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md
index 2aadbb6035d0f..340940f38547c 100644
--- a/mlir/docs/PDLL.md
+++ b/mlir/docs/PDLL.md
@@ -1225,7 +1225,7 @@ was imported:
     - Imported `Type` constraints utilize the `cppClassName` field for native type translation.
 
   * `AttrInterface`/`OpInterface`/`TypeInterface` constraints
-    - Imported interfaces utilize the `cppClassName` field for native type translation.
+    - Imported interfaces utilize the `cppInterfaceName` field for native type translation.
 
 #### Defining Constraints Inline
 

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 89c7122b5ba78..16174454b7a14 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1949,7 +1949,7 @@ class Interface<string name> {
   string description = "";
 
   // The name given to the c++ interface class.
-  string cppClassName = name;
+  string cppInterfaceName = name;
 
   // The C++ namespace that this interface should be placed into.
   //
@@ -1970,13 +1970,25 @@ class Interface<string name> {
 }
 
 // AttrInterface represents an interface registered to an attribute.
-class AttrInterface<string name> : Interface<name>, InterfaceTrait<name>;
+class AttrInterface<string name> : Interface<name>, InterfaceTrait<name>,
+	Attr<CPred<"$_self.isa<"
+		# !if(!empty(cppNamespace),"", cppNamespace # "::") # name # ">()">,
+			name # " instance">
+{
+	let storageType = !if(!empty(cppNamespace), "", cppNamespace # "::") # name;
+	let returnType = storageType;
+	let convertFromStorage = "$_self";
+}
 
 // OpInterface represents an interface registered to an operation.
 class OpInterface<string name> : Interface<name>, OpInterfaceTrait<name>;
 
 // TypeInterface represents an interface registered to a type.
-class TypeInterface<string name> : Interface<name>, InterfaceTrait<name>;
+class TypeInterface<string name> : Interface<name>, InterfaceTrait<name>,
+	Type<CPred<"$_self.isa<"
+		# !if(!empty(cppNamespace),"", cppNamespace # "::") # name # ">()">,
+			name # " instance",
+				!if(!empty(cppNamespace),"", cppNamespace # "::") # name>;
 
 // Whether to declare the interface methods in the user entity's header. This
 // class simply wraps an Interface but is used to indicate that the method
@@ -1992,27 +2004,27 @@ class DeclareInterfaceMethods<list<string> overridenMethods = []> {
 class DeclareAttrInterfaceMethods<AttrInterface interface,
                                   list<string> overridenMethods = []>
       : DeclareInterfaceMethods<overridenMethods>,
-        AttrInterface<interface.cppClassName> {
+        AttrInterface<interface.cppInterfaceName> {
     let description = interface.description;
-    let cppClassName = interface.cppClassName;
+    let cppInterfaceName = interface.cppInterfaceName;
     let cppNamespace = interface.cppNamespace;
     let methods = interface.methods;
 }
 class DeclareOpInterfaceMethods<OpInterface interface,
                                 list<string> overridenMethods = []>
       : DeclareInterfaceMethods<overridenMethods>,
-        OpInterface<interface.cppClassName> {
+        OpInterface<interface.cppInterfaceName> {
     let description = interface.description;
-    let cppClassName = interface.cppClassName;
+    let cppInterfaceName = interface.cppInterfaceName;
     let cppNamespace = interface.cppNamespace;
     let methods = interface.methods;
 }
 class DeclareTypeInterfaceMethods<TypeInterface interface,
                                   list<string> overridenMethods = []>
       : DeclareInterfaceMethods<overridenMethods>,
-        TypeInterface<interface.cppClassName> {
+        TypeInterface<interface.cppInterfaceName> {
     let description = interface.description;
-    let cppClassName = interface.cppClassName;
+    let cppInterfaceName = interface.cppInterfaceName;
     let cppNamespace = interface.cppNamespace;
     let methods = interface.methods;
 }

diff  --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index 4d72ceeb45fc9..1ee0b140756f6 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -81,7 +81,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {
 
 // Return the name of this interface.
 StringRef Interface::getName() const {
-  return def->getValueAsString("cppClassName");
+  return def->getValueAsString("cppInterfaceName");
 }
 
 // Return the C++ namespace of this interface.

diff  --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 4b7fd85227aa0..55b1e3947f3bc 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -873,38 +873,43 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
                           addTypeConstraint(result));
     }
   }
+
+  auto shouldBeSkipped = [this](llvm::Record *def) {
+    return def->isAnonymous() || curDeclScope->lookup(def->getName()) ||
+           def->isSubClassOf("DeclareInterfaceMethods");
+  };
+
   /// Attr constraints.
   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
-    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
-      tblgen::Attribute constraint(def);
-      decls.push_back(
-          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
-              constraint, convertLocToRange(def->getLoc().front()), attrTy,
-              constraint.getStorageType()));
-    }
+    if (shouldBeSkipped(def))
+      continue;
+
+    tblgen::Attribute constraint(def);
+    decls.push_back(createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
+        constraint, convertLocToRange(def->getLoc().front()), attrTy,
+        constraint.getStorageType()));
   }
   /// Type constraints.
   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
-    if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
-      tblgen::TypeConstraint constraint(def);
-      decls.push_back(
-          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
-              constraint, convertLocToRange(def->getLoc().front()), typeTy,
-              constraint.getCPPClassName()));
-    }
+    if (shouldBeSkipped(def))
+      continue;
+
+    tblgen::TypeConstraint constraint(def);
+    decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
+        constraint, convertLocToRange(def->getLoc().front()), typeTy,
+        constraint.getCPPClassName()));
   }
-  /// Interfaces.
+  /// OpInterfaces.
   ast::Type opTy = ast::OperationType::get(ctx);
-  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) {
-    StringRef name = def->getName();
-    if (def->isAnonymous() || curDeclScope->lookup(name) ||
-        def->isSubClassOf("DeclareInterfaceMethods"))
+  for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) {
+    if (shouldBeSkipped(def))
       continue;
+
     SMRange loc = convertLocToRange(def->getLoc().front());
 
     std::string cppClassName =
         llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
-                      def->getValueAsString("cppClassName"))
+                      def->getValueAsString("cppInterfaceName"))
             .str();
     std::string codeBlock =
         llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
@@ -913,18 +918,8 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
 
     std::string desc =
         processAndFormatDoc(def->getValueAsString("description"));
-    if (def->isSubClassOf("OpInterface")) {
-      decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
-          name, codeBlock, loc, opTy, cppClassName, desc));
-    } else if (def->isSubClassOf("AttrInterface")) {
-      decls.push_back(
-          createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
-              name, codeBlock, loc, attrTy, cppClassName, desc));
-    } else if (def->isSubClassOf("TypeInterface")) {
-      decls.push_back(
-          createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
-              name, codeBlock, loc, typeTy, cppClassName, desc));
-    }
+    decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
+        def->getName(), codeBlock, loc, opTy, cppClassName, desc));
   }
 }
 

diff  --git a/mlir/test/mlir-pdll/Parser/include_td.pdll b/mlir/test/mlir-pdll/Parser/include_td.pdll
index f90f7ab8a4126..5526aa852482f 100644
--- a/mlir/test/mlir-pdll/Parser/include_td.pdll
+++ b/mlir/test/mlir-pdll/Parser/include_td.pdll
@@ -32,21 +32,21 @@
 // CHECK-NEXT:   CppClass: ::mlir::IntegerType
 // CHECK-NEXT: }
 
-// CHECK: UserConstraintDecl {{.*}} Name<TestAttrInterface> ResultType<Tuple<>> Code<return ::mlir::success(llvm::isa<::TestAttrInterface>(self));>
+// CHECK: UserConstraintDecl {{.*}} Name<TestAttrInterface> ResultType<Tuple<>> Code<return ::mlir::success((self.isa<TestAttrInterface>()));>
 // CHECK:  `Inputs`
 // CHECK:    `-VariableDecl {{.*}} Name<self> Type<Attr>
 // CHECK:      `Constraints`
 // CHECK:        `-AttrConstraintDecl
 
+// CHECK: UserConstraintDecl {{.*}} Name<TestTypeInterface> ResultType<Tuple<>> Code<return ::mlir::success((self.isa<TestTypeInterface>()));>
+// CHECK:  `Inputs`
+// CHECK:    `-VariableDecl {{.*}} Name<self> Type<Type>
+// CHECK:      `Constraints`
+// CHECK:        `-TypeConstraintDecl {{.*}}
+
 // CHECK: UserConstraintDecl {{.*}} Name<TestOpInterface> ResultType<Tuple<>> Code<return ::mlir::success(llvm::isa<::TestOpInterface>(self));>
 // CHECK:  `Inputs`
 // CHECK:    `-VariableDecl {{.*}} Name<self> Type<Op>
 // CHECK:      `Constraints`
 // CHECK:        `-OpConstraintDecl
 // CHECK:          `-OpNameDecl
-
-// CHECK: UserConstraintDecl {{.*}} Name<TestTypeInterface> ResultType<Tuple<>> Code<return ::mlir::success(llvm::isa<::TestTypeInterface>(self));>
-// CHECK:  `Inputs`
-// CHECK:    `-VariableDecl {{.*}} Name<self> Type<Type>
-// CHECK:      `Constraints`
-// CHECK:        `-TypeConstraintDecl {{.*}}

diff  --git a/mlir/test/mlir-tblgen/interfaces-as-constraints.td b/mlir/test/mlir-tblgen/interfaces-as-constraints.td
new file mode 100644
index 0000000000000..5963dd8bb8acc
--- /dev/null
+++ b/mlir/test/mlir-tblgen/interfaces-as-constraints.td
@@ -0,0 +1,47 @@
+// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+}
+
+def TopLevelTypeInterface : TypeInterface<"TopLevelTypeInterface">;
+
+def TypeInterfaceInNamespace : TypeInterface<"TypeInterfaceInNamespace"> {
+	let cppNamespace = "test";
+}
+
+def TopLevelAttrInterface : AttrInterface<"TopLevelAttrInterface">;
+
+def AttrInterfaceInNamespace : AttrInterface<"AttrInterfaceInNamespace"> {
+	let cppNamespace = "test";
+}
+
+def OpUsingAllOfThose : Op<Test_Dialect, "OpUsingAllOfThose"> {
+	let arguments = (ins TopLevelAttrInterface:$attr1, AttrInterfaceInNamespace:$attr2);
+	let results = (outs TopLevelTypeInterface:$res1, TypeInterfaceInNamespace:$res2);
+}
+
+// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_type_constraint.*}}(
+// CHECK:   if (!((type.isa<TopLevelTypeInterface>()))) {
+// CHECK-NEXT:    return op->emitOpError(valueKind) << " #" << valueIndex
+// CHECK-NEXT:        << " must be TopLevelTypeInterface instance, but got " << type;
+
+// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_type_constraint.*}}(
+// CHECK:   if (!((type.isa<test::TypeInterfaceInNamespace>()))) {
+// CHECK-NEXT:    return op->emitOpError(valueKind) << " #" << valueIndex
+// CHECK-NEXT:        << " must be TypeInterfaceInNamespace instance, but got " << type;
+
+// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_attr_constraint.*}}(
+// CHECK:   if (attr && !((attr.isa<TopLevelAttrInterface>()))) {
+// CHECK-NEXT:    return op->emitOpError("attribute '") << attrName
+// CHECK-NEXT:        << "' failed to satisfy constraint: TopLevelAttrInterface instance";
+
+// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_attr_constraint.*}}(
+// CHECK:   if (attr && !((attr.isa<test::AttrInterfaceInNamespace>()))) {
+// CHECK-NEXT:    return op->emitOpError("attribute '") << attrName
+// CHECK-NEXT:        << "' failed to satisfy constraint: AttrInterfaceInNamespace instance";
+
+// CHECK: TopLevelAttrInterface OpUsingAllOfThose::attr1()
+// CHECK: test::AttrInterfaceInNamespace OpUsingAllOfThose::attr2()

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 212fe0e1204e5..54190aac15fa0 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -2304,7 +2304,7 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
       //    DeclareOpInterfaceMethods<InferTypeOpInterface>
       // and the like.
       // TODO: Add hasCppInterface check.
-      if (auto name = def.getValueAsOptionalString("cppClassName")) {
+      if (auto name = def.getValueAsOptionalString("cppInterfaceName")) {
         if (*name == "InferTypeOpInterface" &&
             def.getValueAsString("cppNamespace") == "::mlir")
           canInferResultTypes = true;


        


More information about the Mlir-commits mailing list