[Mlir-commits] [mlir] [mlir][irdl] Add support for basic structural constraints in tblgen-to-irdl (PR #82862)

Fehr Mathieu llvmlistbot at llvm.org
Fri Feb 23 18:58:39 PST 2024


https://github.com/math-fehr created https://github.com/llvm/llvm-project/pull/82862

Adds tblgen-to-irdl support for `TypeDef`, `AnyType`, `AnyTypeOf`, and `AllOfType` ODS constraints.
This is done by introspecting the TableGen records directly, instead of emitting each constraint's C++ predicate as an `irdl.c_pred`.

For instance, `shape.add` now looks like:
```
    irdl.operation @add {
      %0 = irdl.base "!shape.size" 
      %1 = irdl.c_pred "(::llvm::isa<::mlir::IndexType>($_self))" 
      %2 = irdl.any_of(%0, %1) 
      %3 = irdl.base "!shape.size" 
      %4 = irdl.c_pred "(::llvm::isa<::mlir::IndexType>($_self))" 
      %5 = irdl.any_of(%3, %4) 
      %6 = irdl.base "!shape.size" 
      %7 = irdl.c_pred "(::llvm::isa<::mlir::IndexType>($_self))" 
      %8 = irdl.any_of(%6, %7) 
      irdl.operands(%2, %5)
      irdl.results(%8)
    }
```

instead of the previous output:

```
    irdl.operation @add {
      %0 = irdl.c_pred "((::llvm::isa<::mlir::shape::SizeType>($_self))) || ((::llvm::isa<::mlir::IndexType>($_self)))" 
      %1 = irdl.c_pred "((::llvm::isa<::mlir::shape::SizeType>($_self))) || ((::llvm::isa<::mlir::IndexType>($_self)))" 
      %2 = irdl.c_pred "((::llvm::isa<::mlir::shape::SizeType>($_self))) || ((::llvm::isa<::mlir::IndexType>($_self)))" 
      irdl.operands(%0, %1)
      irdl.results(%2)
    }
```

From 15d244bffc72606efc118342fef7dad882d8f5da Mon Sep 17 00:00:00 2001
From: Mathieu Fehr <mathieu.fehr at gmail.com>
Date: Tue, 26 Dec 2023 22:28:19 +0000
Subject: [PATCH] [mlir][irdl] Add support for structural constraints in
 tblgen-to-irdl

---
 mlir/include/mlir/IR/CommonTypeConstraints.td | 22 +++---
 mlir/test/tblgen-to-irdl/CMathDialect.td      | 12 +--
 mlir/test/tblgen-to-irdl/TestDialect.td       | 74 +++++++++++++++++++
 .../tools/tblgen-to-irdl/OpDefinitionsGen.cpp | 48 ++++++++++--
 4 files changed, 134 insertions(+), 22 deletions(-)
 create mode 100644 mlir/test/tblgen-to-irdl/TestDialect.td

diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 03180a687523bf..af4f13dc09360d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -168,24 +168,28 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
       BuildableType<"$_builder.getType<::mlir::NoneType>()">;
 
 // Any type from the given list
-class AnyTypeOf<list<Type> allowedTypes, string summary = "",
+class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
                 string cppClassName = "::mlir::Type"> : Type<
     // Satisfy any of the allowed types' conditions.
-    Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
+    Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
     !if(!eq(summary, ""),
-        !interleave(!foreach(t, allowedTypes, t.summary), " or "),
+        !interleave(!foreach(t, allowedTypeList, t.summary), " or "),
         summary),
-    cppClassName>;
+    cppClassName> {
+  list<Type> allowedTypes = allowedTypeList; // Exposed for introspection.
+}
 
 // A type that satisfies the constraints of all given types.
-class AllOfType<list<Type> allowedTypes, string summary = "",
+class AllOfType<list<Type> allowedTypeList, string summary = "",
                 string cppClassName = "::mlir::Type"> : Type<
-    // Satisfy all of the allowedf types' conditions.
-    And<!foreach(allowedType, allowedTypes, allowedType.predicate)>,
+    // Satisfy all of the allowed types' conditions.
+    And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
     !if(!eq(summary, ""),
-        !interleave(!foreach(t, allowedTypes, t.summary), " and "),
+        !interleave(!foreach(t, allowedTypeList, t.summary), " and "),
         summary),
-    cppClassName>;
+    cppClassName> {
+  list<Type> allowedTypes = allowedTypeList; // Exposed for introspection.
+}
 
 // A type that satisfies additional predicates.
 class ConfinedType<Type type, list<Pred> predicates, string summary = "",
diff --git a/mlir/test/tblgen-to-irdl/CMathDialect.td b/mlir/test/tblgen-to-irdl/CMathDialect.td
index 57ae8afbba5eeb..5b9e756727cb36 100644
--- a/mlir/test/tblgen-to-irdl/CMathDialect.td
+++ b/mlir/test/tblgen-to-irdl/CMathDialect.td
@@ -24,7 +24,7 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
 }
 
 // CHECK:      irdl.operation @identity {
-// CHECK-NEXT:   %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))" 
+// CHECK-NEXT:   %0 = irdl.base "!cmath.complex"
 // CHECK-NEXT:   irdl.operands()
 // CHECK-NEXT:   irdl.results(%0)
 // CHECK-NEXT: }
@@ -33,9 +33,9 @@ def CMath_IdentityOp : CMath_Op<"identity"> {
 }
 
 // CHECK:      irdl.operation @mul {
-// CHECK-NEXT:   %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))" 
-// CHECK-NEXT:   %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))" 
-// CHECK-NEXT:   %2 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))" 
+// CHECK-NEXT:   %0 = irdl.base "!cmath.complex"
+// CHECK-NEXT:   %1 = irdl.base "!cmath.complex"
+// CHECK-NEXT:   %2 = irdl.base "!cmath.complex"
 // CHECK-NEXT:   irdl.operands(%0, %1)
 // CHECK-NEXT:   irdl.results(%2)
 // CHECK-NEXT: }
@@ -45,8 +45,8 @@ def CMath_MulOp : CMath_Op<"mul"> {
 }
 
 // CHECK:      irdl.operation @norm {
-// CHECK-NEXT:   %0 = irdl.c_pred "(true)" 
-// CHECK-NEXT:   %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))" 
+// CHECK-NEXT:   %0 = irdl.any
+// CHECK-NEXT:   %1 = irdl.base "!cmath.complex"
 // CHECK-NEXT:   irdl.operands(%0)
 // CHECK-NEXT:   irdl.results(%1)
 // CHECK-NEXT: }
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
new file mode 100644
index 00000000000000..fc40da527db00a
--- /dev/null
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -0,0 +1,74 @@
+// RUN: tblgen-to-irdl %s -I=%S/../../include --gen-dialect-irdl-defs --dialect=test | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/AttrTypeBase.td"
+
+// CHECK-LABEL: irdl.dialect @test {
+def Test_Dialect : Dialect {
+  let name = "test";
+}
+
+class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
+: TypeDef<Test_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+class Test_Op<string mnemonic, list<Trait> traits = []>
+    : Op<Test_Dialect, mnemonic, traits>;
+
+def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
+def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
+def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
+
+
+// Check that AllOfType is converted to irdl.all_of over its allowed types.
+def Test_AndOp : Test_Op<"and"> {
+  let arguments = (ins AllOfType<[Test_SingletonAType, AnyType]>:$in);
+}
+// CHECK-LABEL: irdl.operation @and {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.any
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]]) 
+// CHECK-NEXT:    irdl.operands(%[[v2]])
+// CHECK-NEXT:    irdl.results()
+// CHECK-NEXT:  }
+
+
+// Check that AnyType is converted to irdl.any.
+def Test_AnyOp : Test_Op<"any"> {
+  let arguments = (ins AnyType:$in);
+}
+// CHECK-LABEL: irdl.operation @any {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.any
+// CHECK-NEXT:    irdl.operands(%[[v0]])
+// CHECK-NEXT:    irdl.results()
+// CHECK-NEXT:  }
+
+
+// Check that AnyTypeOf is converted to irdl.any_of over its allowed types.
+def Test_OrOp : Test_Op<"or"> {
+  let arguments = (ins AnyTypeOf<[Test_SingletonAType, Test_SingletonBType, Test_SingletonCType]>:$in);
+}
+// CHECK-LABEL: irdl.operation @or {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
+// CHECK-NEXT:    %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]]) 
+// CHECK-NEXT:    irdl.operands(%[[v3]])
+// CHECK-NEXT:    irdl.results()
+// CHECK-NEXT:  }
+
+
+// Check that Variadic/Optional operands are constrained by their base type.
+def Test_VariadicityOp : Test_Op<"variadicity"> {
+  let arguments = (ins Variadic<Test_SingletonAType>:$variadic,
+                       Optional<Test_SingletonBType>:$optional,
+                       Test_SingletonCType:$required);
+}
+// CHECK-LABEL: irdl.operation @variadicity {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
+// CHECK-NEXT:    irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
+// CHECK-NEXT:    irdl.results()
+// CHECK-NEXT:  }
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index ba5bf4d9d4abbc..a55f3539f31db0 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -39,15 +39,49 @@ llvm::cl::opt<std::string>
     selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
                     llvm::cl::cat(dialectGenCat), llvm::cl::Required);
 
-irdl::CPredOp createConstraint(OpBuilder &builder,
-                               NamedTypeConstraint namedConstraint) {
+/// Build the IRDL constraint for an ODS type constraint. AnyType, TypeDef,
+/// AnyTypeOf and AllOfType map to irdl.any, irdl.base, irdl.any_of and
+/// irdl.all_of; Variadic/Optional are unwrapped to their base type; anything
+/// else falls back to an irdl.c_pred holding the constraint's C++ predicate.
+Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   MLIRContext *ctx = builder.getContext();
-  // Build the constraint as a string.
-  std::string constraint =
-      namedConstraint.constraint.getPredicate().getCondition();
+  const Record &predRec = constraint.getDef();
+  Location loc = UnknownLoc::get(ctx);
+
+  // Variadic and Optional only wrap a base type (the variadicity itself is
+  // recorded by the caller), so constrain the wrapped type.
+  if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
+    return createConstraint(
+        builder, tblgen::Constraint(predRec.getValueAsDef("baseType")));
+
+  if (predRec.getName() == "AnyType")
+    return builder.create<irdl::AnyOp>(loc).getOutput();
+
+  // A TypeDef is matched by its base name, e.g. "!cmath.complex".
+  if (predRec.isSubClassOf("TypeDef")) {
+    std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
+    return builder.create<irdl::BaseOp>(loc, StringAttr::get(ctx, typeName))
+        .getOutput();
+  }
+
+  // AnyTypeOf and AllOfType expose their component types as `allowedTypes`;
+  // build one IRDL constraint per component type.
+  auto createAllowedTypes = [&]() {
+    SmallVector<Value> children;
+    for (Record *child : predRec.getValueAsListOfDefs("allowedTypes"))
+      children.push_back(createConstraint(builder, tblgen::Constraint(child)));
+    return children;
+  };
+  if (predRec.isSubClassOf("AnyTypeOf"))
+    return builder.create<irdl::AnyOfOp>(loc, createAllowedTypes()).getOutput();
+  if (predRec.isSubClassOf("AllOfType"))
+    return builder.create<irdl::AllOfOp>(loc, createAllowedTypes()).getOutput();
+
+  // Fallback: reuse the constraint's C++ predicate verbatim.
+  std::string condition = constraint.getPredicate().getCondition();
   // Build a CPredOp to match the C constraint built.
   irdl::CPredOp op = builder.create<irdl::CPredOp>(
-      UnknownLoc::get(ctx), StringAttr::get(ctx, constraint));
+      loc, StringAttr::get(ctx, condition));
   return op;
 }
 
@@ -74,7 +108,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
     SmallVector<Value> operands;
     SmallVector<irdl::VariadicityAttr> variadicity;
     for (const NamedTypeConstraint &namedCons : namedCons) {
-      auto operand = createConstraint(consBuilder, namedCons);
+      auto operand = createConstraint(consBuilder, namedCons.constraint);
       operands.push_back(operand);
 
       irdl::VariadicityAttr var;



More information about the Mlir-commits mailing list