[Mlir-commits] [mlir] Refactor tblgen-to-irdl script and support more types (PR #105505)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 21 05:05:00 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-ods

Author: Alex Rice (alexarice)

<details>
<summary>Changes</summary>

Refactors the tblgen-to-irdl script slightly and adds support for:
- various integer types
- various float types
- confined types
- complex types (with a fixed element type)

Also stops emitting the `irdl.operands` and `irdl.results` ops when the operation has no operands or no results, respectively.

I can split this into smaller PRs if that would be helpful (refactor + integer/float/complex types; confined types; omitting empty operands/results ops).

@<!-- -->math-fehr 

---
Full diff: https://github.com/llvm/llvm-project/pull/105505.diff


4 Files Affected:

- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4-1) 
- (modified) mlir/test/tblgen-to-irdl/CMathDialect.td (-1) 
- (modified) mlir/test/tblgen-to-irdl/TestDialect.td (+51-6) 
- (modified) mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp (+169-9) 


``````````diff
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4536d781ef674f..0e076413d0d9f3 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -198,7 +198,10 @@ class AllOfType<list<Type> allowedTypeList, string summary = "",
 class ConfinedType<Type type, list<Pred> predicates, string summary = "",
                    string cppType = type.cppType> : Type<
     And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
-    summary, cppType>;
+    summary, cppType> {
+    Type baseType = type;
+    list<Pred> predicateList = predicates;
+}
 
 // Integer types.
 
diff --git a/mlir/test/tblgen-to-irdl/CMathDialect.td b/mlir/test/tblgen-to-irdl/CMathDialect.td
index 5b9e756727cb36..454543e074c489 100644
--- a/mlir/test/tblgen-to-irdl/CMathDialect.td
+++ b/mlir/test/tblgen-to-irdl/CMathDialect.td
@@ -25,7 +25,6 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
 
 // CHECK:      irdl.operation @identity {
 // CHECK-NEXT:   %0 = irdl.base "!cmath.complex"
-// CHECK-NEXT:   irdl.operands()
 // CHECK-NEXT:   irdl.results(%0)
 // CHECK-NEXT: }
 def CMath_IdentityOp : CMath_Op<"identity"> {
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index fc40da527db00a..a86dcb5b3b66e2 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -28,9 +28,8 @@ def Test_AndOp : Test_Op<"and"> {
 // CHECK-LABEL: irdl.operation @and {
 // CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
 // CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.any
-// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]]) 
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
 // CHECK-NEXT:    irdl.operands(%[[v2]])
-// CHECK-NEXT:    irdl.results()
 // CHECK-NEXT:  }
 
 
@@ -41,9 +40,37 @@ def Test_AnyOp : Test_Op<"any"> {
 // CHECK-LABEL: irdl.operation @any {
 // CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.any
 // CHECK-NEXT:    irdl.operands(%[[v0]])
-// CHECK-NEXT:    irdl.results()
 // CHECK-NEXT:  }
 
+// Check confined types are converted correctly.
+def Test_ConfinedOp : Test_Op<"confined"> {
+  let arguments = (ins ConfinedType<I32, [IntNonNegative.predicate]>:$confined,
+                       ConfinedType<I8, [And<[IntMinValue<1>.predicate, IntMaxValue<2>.predicate]>]>:$bounded);
+}
+// CHECK-LABEL: irdl.operation @confined {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.is i32
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.c_pred "{{.*}}"
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
+// CHECK-NEXT:    %[[v3:[^ ]*]] = irdl.is i8
+// CHECK-NEXT:    %[[v4:[^ ]*]] = irdl.c_pred "{{.*}}"
+// CHECK-NEXT:    %[[v5:[^ ]*]] = irdl.c_pred "{{.*}}"
+// CHECK-NEXT:    %[[v6:[^ ]*]] = irdl.all_of(%[[v4]], %[[v5]])
+// CHECK-NEXT:    %[[v7:[^ ]*]] = irdl.all_of(%[[v3]], %[[v6]])
+// CHECK-NEXT:    irdl.operands(%[[v2]], %[[v7]])
+// CHECK-NEXT:  }
+
+def Test_Integers : Test_Op<"integers"> {
+  let arguments = (ins AnyI8:$any_int,
+                       AnyInteger:$any_integer);
+}
+// CHECK-LABEL: irdl.operation @integers {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.is i8
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.is si8
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.is ui8
+// CHECK-NEXT:    %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
+// CHECK-NEXT:    %[[v4:[^ ]*]] = irdl.base "!builtin.integer"
+// CHECK-NEXT:    irdl.operands(%[[v3]], %[[v4]])
+// CHECK-NEXT:  }
 
 // Check that AnyTypeOf is converted correctly.
 def Test_OrOp : Test_Op<"or"> {
@@ -53,11 +80,30 @@ def Test_OrOp : Test_Op<"or"> {
 // CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
 // CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
 // CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
-// CHECK-NEXT:    %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]]) 
+// CHECK-NEXT:    %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
 // CHECK-NEXT:    irdl.operands(%[[v3]])
-// CHECK-NEXT:    irdl.results()
 // CHECK-NEXT:  }
 
+// Check that various types are converted correctly.
+def Test_TypesOp : Test_Op<"types"> {
+  let arguments = (ins I32:$a,
+                       SI64:$b,
+                       UI8:$c,
+                       Index:$d,
+                       F32:$e,
+                       NoneType:$f,
+                       Complex<F8E4M3FN>);
+}
+// CHECK-LABEL: irdl.operation @types {
+// CHECK-NEXT:    %{{.*}} = irdl.is i32
+// CHECK-NEXT:    %{{.*}} = irdl.is si64
+// CHECK-NEXT:    %{{.*}} = irdl.is ui8
+// CHECK-NEXT:    %{{.*}} = irdl.is index
+// CHECK-NEXT:    %{{.*}} = irdl.is f32
+// CHECK-NEXT:    %{{.*}} = irdl.is none
+// CHECK-NEXT:    %{{.*}} = irdl.is complex<f8E4M3FN>
+// CHECK-NEXT:    irdl.operands({{.*}})
+// CHECK-NEXT:  }
 
 // Check that variadics and optionals are converted correctly.
 def Test_VariadicityOp : Test_Op<"variadicity"> {
@@ -70,5 +116,4 @@ def Test_VariadicityOp : Test_Op<"variadicity"> {
 // CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
 // CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
 // CHECK-NEXT:    irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
-// CHECK-NEXT:    irdl.results()
 // CHECK-NEXT:  }
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index a55f3539f31db0..181d02c6608bdb 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -39,6 +39,130 @@ llvm::cl::opt<std::string>
     selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
                     llvm::cl::cat(dialectGenCat), llvm::cl::Required);
 
+Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
+  MLIRContext *ctx = builder.getContext();
+
+  if (pred.isCombined()) {
+    auto combiner = pred.getDef().getValueAsDef("kind")->getName();
+    if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr") {
+      std::vector<Value> constraints;
+      for (auto *child : pred.getDef().getValueAsListOfDefs("children")) {
+        constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
+      }
+      if (combiner == "PredCombinerAnd") {
+        auto op =
+            builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+        return op.getOutput();
+      }
+      auto op =
+          builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+      return op.getOutput();
+    }
+  }
+
+  std::string condition = pred.getCondition();
+  // Build a CPredOp to match the C constraint built.
+  irdl::CPredOp op = builder.create<irdl::CPredOp>(
+      UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
+  return op;
+}
+
+Value typeToConstraint(OpBuilder &builder, MLIRContext *ctx, Type type) {
+  auto op =
+      builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type));
+  return op.getOutput();
+}
+
+std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
+
+  if (predRec.isSubClassOf("I")) {
+    auto width = predRec.getValueAsInt("bitwidth");
+    return IntegerType::get(ctx, width, IntegerType::Signless);
+  }
+
+  if (predRec.isSubClassOf("SI")) {
+    auto width = predRec.getValueAsInt("bitwidth");
+    return IntegerType::get(ctx, width, IntegerType::Signed);
+  }
+
+  if (predRec.isSubClassOf("UI")) {
+    auto width = predRec.getValueAsInt("bitwidth");
+    return IntegerType::get(ctx, width, IntegerType::Unsigned);
+  }
+
+  // Index type
+  if (predRec.getName() == "Index") {
+    return IndexType::get(ctx);
+  }
+
+  // Float types
+  if (predRec.isSubClassOf("F")) {
+    auto width = predRec.getValueAsInt("bitwidth");
+    switch (width) {
+    case 16:
+      return FloatType::getF16(ctx);
+    case 32:
+      return FloatType::getF32(ctx);
+    case 64:
+      return FloatType::getF64(ctx);
+    case 80:
+      return FloatType::getF80(ctx);
+    case 128:
+      return FloatType::getF128(ctx);
+    }
+  }
+
+  if (predRec.getName() == "NoneType") {
+    return NoneType::get(ctx);
+  }
+
+  if (predRec.getName() == "BF16") {
+    return FloatType::getBF16(ctx);
+  }
+
+  if (predRec.getName() == "TF32") {
+    return FloatType::getTF32(ctx);
+  }
+
+  if (predRec.getName() == "F8E4M3FN") {
+    return FloatType::getFloat8E4M3FN(ctx);
+  }
+
+  if (predRec.getName() == "F8E5M2") {
+    return FloatType::getFloat8E5M2(ctx);
+  }
+
+  if (predRec.getName() == "F8E4M3") {
+    return FloatType::getFloat8E4M3(ctx);
+  }
+
+  if (predRec.getName() == "F8E4M3FNUZ") {
+    return FloatType::getFloat8E4M3FNUZ(ctx);
+  }
+
+  if (predRec.getName() == "F8E4M3B11FNUZ") {
+    return FloatType::getFloat8E4M3B11FNUZ(ctx);
+  }
+
+  if (predRec.getName() == "F8E5M2FNUZ") {
+    return FloatType::getFloat8E5M2FNUZ(ctx);
+  }
+
+  if (predRec.getName() == "F8E3M4") {
+    return FloatType::getFloat8E3M4(ctx);
+  }
+
+  if (predRec.isSubClassOf("Complex")) {
+    const Record *elementRec = predRec.getValueAsDef("elementType");
+    auto elementType = recordToType(ctx, *elementRec);
+    if (elementType.has_value()) {
+      return ComplexType::get(elementType.value());
+    }
+  }
+
+  return std::nullopt;
+}
+
 Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   MLIRContext *ctx = builder.getContext();
   const Record &predRec = constraint.getDef();
@@ -78,11 +202,45 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
     return op.getOutput();
   }
 
-  std::string condition = constraint.getPredicate().getCondition();
-  // Build a CPredOp to match the C constraint built.
-  irdl::CPredOp op = builder.create<irdl::CPredOp>(
-      UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
-  return op;
+  // Integer types
+  if (predRec.getName() == "AnyInteger") {
+    auto op = builder.create<irdl::BaseOp>(
+        UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer"));
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AnyI")) {
+    auto width = predRec.getValueAsInt("bitwidth");
+    std::vector<Value> types = {
+        typeToConstraint(builder, ctx,
+                         IntegerType::get(ctx, width, IntegerType::Signless)),
+        typeToConstraint(builder, ctx,
+                         IntegerType::get(ctx, width, IntegerType::Signed)),
+        typeToConstraint(builder, ctx,
+                         IntegerType::get(ctx, width, IntegerType::Unsigned))};
+    auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types);
+    return op.getOutput();
+  }
+
+  auto type = recordToType(ctx, predRec);
+
+  if (type.has_value()) {
+    return typeToConstraint(builder, ctx, type.value());
+  }
+
+  // Confined type
+  if (predRec.isSubClassOf("ConfinedType")) {
+    std::vector<Value> constraints;
+    constraints.push_back(createConstraint(
+        builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
+    for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
+      constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
+    }
+    auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+    return op.getOutput();
+  }
+
+  return createPredicate(builder, constraint.getPredicate());
 }
 
 /// Returns the name of the operation without the dialect prefix.
@@ -131,10 +289,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
   auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
 
   // Create the operands and results operations.
-  consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
-                                       operandVariadicity);
-  consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
-                                      resultVariadicity);
+  if (!operands.empty())
+    consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
+					 operandVariadicity);
+  if (!results.empty())
+    consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
+					resultVariadicity);
 
   return op;
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/105505


More information about the Mlir-commits mailing list