[Mlir-commits] [mlir] [mlir][ODS] Verify type constraints in Types and Attributes (PR #102326)

River Riddle llvmlistbot at llvm.org
Wed Aug 7 09:17:08 PDT 2024


================
@@ -295,24 +309,91 @@ void DefGen::emitDialectName() {
 void DefGen::emitBuilders() {
   if (!def.skipDefaultBuilders()) {
     emitDefaultBuilder();
-    if (def.genVerifyDecl())
+    if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
       emitCheckedBuilder();
   }
   for (auto &builder : def.getBuilders()) {
     emitCustomBuilder(builder);
-    if (def.genVerifyDecl())
+    if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
       emitCheckedCustomBuilder(builder);
   }
 }
 
-void DefGen::emitVerifier() {
-  defCls.declare<UsingDeclaration>("Base::getChecked");
+void DefGen::emitVerifierDecl() {
   defCls.declareStaticMethod(
       "::llvm::LogicalResult", "verify",
       getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
                          "emitError"}}));
 }
 
+static const char *const patternParameterVerificationCode = R"(
+if (!({0})) {
+  emitError() << "failed to verify '{1}': {2}";
+  return ::mlir::failure();
+}
+)";
+
+void DefGen::emitInvariantsVerifierImpl() {
+  SmallVector<MethodParameter> builderParams = getBuilderParams(
+      {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
+  Method *verifier =
+      defCls.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl",
+                       Method::Static, builderParams);
+  verifier->body().indent();
+
+  // Generate verification for each parameter that is a type constraint.
+  for (auto it : llvm::enumerate(def.getParameters())) {
+    const AttrOrTypeParameter &param = it.value();
+    std::optional<TypeConstraint> constraint = param.getTypeConstraint();
+    // No verification needed for parameters that are not type constraints.
+    if (!constraint.has_value())
+      continue;
+    FmtContext ctx;
+    // Note: Skip over the first method parameter (`emitError`).
+    ctx.withSelf(builderParams[it.index() + 1].getName());
+    std::string condition = tgfmt(constraint->getConditionTemplate(), &ctx);
+    verifier->body() << formatv(patternParameterVerificationCode, condition,
+                                param.getName(), constraint->getSummary())
+                     << "\n";
+  }
+  verifier->body() << "return ::mlir::success();";
+}
+
+void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) {
+  if (!hasImpl && !hasCustomVerifier)
+    return;
+  defCls.declare<UsingDeclaration>("Base::getChecked");
+  SmallVector<MethodParameter> builderParams = getBuilderParams(
+      {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
+  Method *verifier =
+      defCls.addMethod("::llvm::LogicalResult", "verifyInvariants",
+                       Method::Static, builderParams);
+  verifier->body().indent();
+  if (hasImpl) {
+    // Call the verifier that checks the type constraints.
+    verifier->body() << "if (::mlir::failed(verifyInvariantsImpl(";
+    for (int i = 0, e = builderParams.size(); i < e; ++i) {
+      if (i > 0)
+        verifier->body() << ", ";
+      verifier->body() << builderParams[i].getName();
+    }
----------------
River707 wrote:

Can you use llvm::interleave (or llvm::interleaveComma) here instead of the manual comma-separator loop? Same below.

https://github.com/llvm/llvm-project/pull/102326


More information about the Mlir-commits mailing list