[Mlir-commits] [mlir] 62df4df - [mlir-tblgen] Minor Refactor for StaticVerifierFunctionEmitter.

Chia-hung Duan llvmlistbot at llvm.org
Thu Aug 12 13:54:00 PDT 2021


Author: Chia-hung Duan
Date: 2021-08-12T20:53:05Z
New Revision: 62df4df41c939205b2dc0a2a3bfb75b8c1ed74fa

URL: https://github.com/llvm/llvm-project/commit/62df4df41c939205b2dc0a2a3bfb75b8c1ed74fa
DIFF: https://github.com/llvm/llvm-project/commit/62df4df41c939205b2dc0a2a3bfb75b8c1ed74fa.diff

LOG: [mlir-tblgen] Minor Refactor for StaticVerifierFunctionEmitter.

Move StaticVerifierFunctionEmitter to CodeGenHelpers.h so that it can be
used for both ODS and DRR.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D106636

Added: 
    mlir/tools/mlir-tblgen/CodeGenHelpers.cpp

Modified: 
    mlir/include/mlir/TableGen/CodeGenHelpers.h
    mlir/tools/mlir-tblgen/CMakeLists.txt
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index acd80698b3463..68fdf545e23b6 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -13,13 +13,21 @@
 #ifndef MLIR_TABLEGEN_CODEGENHELPERS_H
 #define MLIR_TABLEGEN_CODEGENHELPERS_H
 
+#include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/Dialect.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
 
+namespace llvm {
+class RecordKeeper;
+} // namespace llvm
+
 namespace mlir {
 namespace tblgen {
 
+class Constraint;
+
 // Simple RAII helper for defining ifdef-undef-endif scopes.
 class IfDefScope {
 public:
@@ -62,6 +70,82 @@ class NamespaceEmitter {
   SmallVector<StringRef, 2> namespaces;
 };
 
+/// This class deduplicates shared operation verification code by emitting
+/// static functions alongside the op definitions. These methods are local to
+/// the definition file, and are invoked within the operation verify methods.
+/// An example is shown below:
+///
+/// static LogicalResult localVerify(...)
+///
+/// LogicalResult OpA::verify(...) {
+///  if (failed(localVerify(...)))
+///    return failure();
+///  ...
+/// }
+///
+/// LogicalResult OpB::verify(...) {
+///  if (failed(localVerify(...)))
+///    return failure();
+///  ...
+/// }
+///
+class StaticVerifierFunctionEmitter {
+public:
+  StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
+                                raw_ostream &os);
+
+  /// Emit the static verifier functions for the type constraints used by the
+  /// ops in `opDefs`, which must not be empty. `signatureFormat` describes the
+  /// required arguments and must have a `{0}` placeholder for the name.
+  /// Example,
+  ///   const char *typeVerifierSignature =
+  ///     "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type"
+  ///     " type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
+  ///
+  /// `errorHandlerFormat` is the expression printed after "return " when the
+  /// check fails; its `{0}` placeholder, if present, is replaced by the
+  /// constraint summary.
+  /// Example,
+  ///   const char *typeVerifierErrorHandler =
+  ///       " op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << "
+  ///       "\" must be {0}, but got \" << type";
+  /// `typeArgName` names the parameter holding the type to check; it replaces
+  /// `$_self` in each constraint's condition. If `emitDecl` is true, nothing
+  /// is printed, but the function names are still recorded.
+  void emitFunctionsFor(StringRef signatureFormat, StringRef errorHandlerFormat,
+                        StringRef typeArgName, ArrayRef<llvm::Record *> opDefs,
+                        bool emitDecl);
+
+  /// Get the name of the local function used for the given type constraint,
+  /// which must have been registered by emitFunctionsFor() (asserted). With
+  /// the signature used by ODS, these functions have the form:
+  ///   LogicalResult(Operation *op, Type type, StringRef valueKind,
+  ///                 unsigned valueGroupStartIndex);
+  StringRef getTypeConstraintFn(const Constraint &constraint) const;
+
+private:
+  /// Derive a label from the input .td file name to keep local names unique.
+  static std::string getUniqueName(const llvm::RecordKeeper &records);
+
+  /// Assign a local function name to each type constraint used by `opDefs`
+  /// and, unless `emitDecl` is set, print the function's definition to `os`.
+  void emitTypeConstraintMethods(StringRef signatureFormat,
+                                 StringRef errorHandlerFormat,
+                                 StringRef typeArgName,
+                                 ArrayRef<llvm::Record *> opDefs,
+                                 bool emitDecl);
+
+  raw_indented_ostream os;
+
+  /// A unique label for the file currently being generated. This is used to
+  /// ensure that the local functions have a unique name.
+  std::string uniqueOutputLabel;
+
+  /// Maps each type constraint (by opaque pointer) to the name of the local
+  /// function that verifies it; equivalent constraints share one name.
+  llvm::DenseMap<const void *, std::string> localTypeConstraints;
+};
+
 } // namespace tblgen
 } // namespace mlir
 

diff  --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index 083f70e40e6d6..f16e8965daca4 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_LINK_COMPONENTS
 
 add_tablegen(mlir-tblgen MLIR
   AttrOrTypeDefGen.cpp
+  CodeGenHelpers.cpp
   DialectGen.cpp
   DirectiveCommonGen.cpp
   EnumsGen.cpp

diff  --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
new file mode 100644
index 0000000000000..c003ad9eb673c
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
@@ -0,0 +1,145 @@
+//===- CodeGenHelpers.cpp - MLIR TableGen code generation helpers ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains helpers shared by the mlir-tblgen backends (ODS and DRR),
+// such as StaticVerifierFunctionEmitter, which emits local verifier functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Path.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::tblgen;
+
+StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
+    const llvm::RecordKeeper &records, raw_ostream &os)
+    : os(os), uniqueOutputLabel(getUniqueName(records)) {}
+
+void StaticVerifierFunctionEmitter::emitFunctionsFor(
+    StringRef signatureFormat, StringRef errorHandlerFormat,
+    StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
+  // The C++ namespace is taken from the first op, so an empty list (which
+  // has no constraints to emit anyway) must return before indexing into it.
+  if (opDefs.empty())
+    return;
+
+  // Only definitions are printed, inside the ops' C++ namespace; for
+  // declarations the function names are merely recorded.
+  llvm::Optional<NamespaceEmitter> namespaceEmitter;
+  if (!emitDecl)
+    namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
+
+  emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName,
+                            opDefs, emitDecl);
+}
+
+StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
+    const Constraint &constraint) const {
+  auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
+  assert(it != localTypeConstraints.end() && "expected valid constraint fn");
+  return it->second;
+}
+
+std::string StaticVerifierFunctionEmitter::getUniqueName(
+    const llvm::RecordKeeper &records) {
+  // Use the input file name when generating a unique name.
+  std::string inputFilename = records.getInputFilename();
+
+  // Drop all but the base filename.
+  StringRef nameRef = llvm::sys::path::filename(inputFilename);
+  nameRef.consume_back(".td");
+
+  // Sanitize any invalid characters.
+  std::string uniqueName;
+  for (char c : nameRef) {
+    if (llvm::isAlnum(c) || c == '_')
+      uniqueName.push_back(c);
+    else
+      uniqueName.append(llvm::utohexstr((unsigned char)c));
+  }
+  return uniqueName;
+}
+
+void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
+    StringRef signatureFormat, StringRef errorHandlerFormat,
+    StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
+  // Collect a set of all of the used type constraints within the operation
+  // definitions.
+  llvm::SetVector<const void *> typeConstraints;
+  for (Record *def : opDefs) {
+    Operator op(*def);
+    for (NamedTypeConstraint &operand : op.getOperands())
+      if (operand.hasPredicate())
+        typeConstraints.insert(operand.constraint.getAsOpaquePointer());
+    for (NamedTypeConstraint &result : op.getResults())
+      if (result.hasPredicate())
+        typeConstraints.insert(result.constraint.getAsOpaquePointer());
+  }
+
+  // llvm::formatv() expects a null-terminated format string, which a
+  // StringRef does not guarantee, so take owned copies once up front.
+  std::string signatureFmt = signatureFormat.str();
+  std::string errorHandlerFmt = errorHandlerFormat.str();
+
+  // Map each predicate to all constraints that own an emitted function for
+  // it. Different constraints may share a predicate, e.g. ConstraintA and
+  // Variadic<ConstraintA> (Variadic<> adds no predicate). They can share one
+  // verification function only if their summaries also match; otherwise a
+  // failing check would report the wrong message. DenseMap keys are unique,
+  // so every candidate for a predicate is kept in the mapped vector.
+  llvm::DenseMap<Pred, llvm::SmallVector<const void *, 1>> predToConstraints;
+  FmtContext fctx;
+  for (auto it : llvm::enumerate(typeConstraints)) {
+    Constraint constraint = Constraint::getFromOpaquePointer(it.value());
+    llvm::SmallVectorImpl<const void *> &candidates =
+        predToConstraints[constraint.getPredicate()];
+
+    // Reuse the function of an already-emitted, equivalent constraint.
+    auto equivalent = llvm::find_if(candidates, [&](const void *candidate) {
+      return Constraint::getFromOpaquePointer(candidate).getSummary() ==
+             constraint.getSummary();
+    });
+    if (equivalent != candidates.end()) {
+      std::string name =
+          getTypeConstraintFn(Constraint::getFromOpaquePointer(*equivalent))
+              .str();
+      localTypeConstraints.try_emplace(it.value(), name);
+      continue;
+    }
+
+    // Generate an obscure and unique name for this type constraint.
+    std::string name =
+        (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
+         Twine(it.index()))
+            .str();
+    candidates.push_back(it.value());
+    localTypeConstraints.try_emplace(it.value(), name);
+
+    // Only generate the methods if we are generating definitions.
+    if (emitDecl)
+      continue;
+
+    os << formatv(signatureFmt.c_str(), name) << " {\n";
+    os.indent() << "if (!("
+                << tgfmt(constraint.getConditionTemplate(),
+                         &fctx.withSelf(typeArgName))
+                << ")) {\n";
+    os.indent() << "return "
+                << formatv(errorHandlerFmt.c_str(), constraint.getSummary())
+                << ";\n";
+    os.unindent() << "}\nreturn ::mlir::success();\n";
+    os.unindent() << "}\n\n";
+  }
+}

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 2a851618938b0..269803f50788f 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -24,7 +24,6 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/Path.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -101,6 +100,14 @@ const char *valueRangeReturnCode = R"(
            std::next({0}, valueRange.first + valueRange.second)};
 )";
 
+// Signature and failure expression for the local type constraint verifiers.
+static const char *const typeVerifierSignature =
+    "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type "
+    "type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
+static const char *const typeVerifierErrorHandler =
+    "op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << \" must "
+    "be {0}, but got \" << type";
+
 static const char *const opCommentHeader = R"(
 //===----------------------------------------------------------------------===//
 // {0} {1}
@@ -108,175 +115,6 @@ static const char *const opCommentHeader = R"(
 
 )";
 
-//===----------------------------------------------------------------------===//
-// StaticVerifierFunctionEmitter
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This class deduplicates shared operation verification code by emitting
-/// static functions alongside the op definitions. These methods are local to
-/// the definition file, and are invoked within the operation verify methods.
-/// An example is shown below:
-///
-/// static LogicalResult localVerify(...)
-///
-/// LogicalResult OpA::verify(...) {
-///  if (failed(localVerify(...)))
-///    return failure();
-///  ...
-/// }
-///
-/// LogicalResult OpB::verify(...) {
-///  if (failed(localVerify(...)))
-///    return failure();
-///  ...
-/// }
-///
-class StaticVerifierFunctionEmitter {
-public:
-  StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
-                                ArrayRef<llvm::Record *> opDefs,
-                                raw_ostream &os, bool emitDecl);
-
-  /// Get the name of the local function used for the given type constraint.
-  /// These functions are used for operand and result constraints and have the
-  /// form:
-  ///   LogicalResult(Operation *op, Type type, StringRef valueKind,
-  ///                 unsigned valueGroupStartIndex);
-  StringRef getTypeConstraintFn(const Constraint &constraint) const {
-    auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
-    assert(it != localTypeConstraints.end() && "expected valid constraint fn");
-    return it->second;
-  }
-
-private:
-  /// Returns a unique name to use when generating local methods.
-  static std::string getUniqueName(const llvm::RecordKeeper &records);
-
-  /// Emit local methods for the type constraints used within the provided op
-  /// definitions.
-  void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
-                                 raw_ostream &os, bool emitDecl);
-
-  /// A unique label for the file currently being generated. This is used to
-  /// ensure that the local functions have a unique name.
-  std::string uniqueOutputLabel;
-
-  /// A set of functions implementing type constraints, used for operand and
-  /// result verification.
-  llvm::DenseMap<const void *, std::string> localTypeConstraints;
-};
-} // namespace
-
-StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
-    const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
-    raw_ostream &os, bool emitDecl)
-    : uniqueOutputLabel(getUniqueName(records)) {
-  llvm::Optional<NamespaceEmitter> namespaceEmitter;
-  if (!emitDecl) {
-    os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
-    namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
-  }
-
-  emitTypeConstraintMethods(opDefs, os, emitDecl);
-}
-
-std::string StaticVerifierFunctionEmitter::getUniqueName(
-    const llvm::RecordKeeper &records) {
-  // Use the input file name when generating a unique name.
-  std::string inputFilename = records.getInputFilename();
-
-  // Drop all but the base filename.
-  StringRef nameRef = llvm::sys::path::filename(inputFilename);
-  nameRef.consume_back(".td");
-
-  // Sanitize any invalid characters.
-  std::string uniqueName;
-  for (char c : nameRef) {
-    if (llvm::isAlnum(c) || c == '_')
-      uniqueName.push_back(c);
-    else
-      uniqueName.append(llvm::utohexstr((unsigned char)c));
-  }
-  return uniqueName;
-}
-
-void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
-    ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
-  // Collect a set of all of the used type constraints within the operation
-  // definitions.
-  llvm::SetVector<const void *> typeConstraints;
-  for (Record *def : opDefs) {
-    Operator op(*def);
-    for (NamedTypeConstraint &operand : op.getOperands())
-      if (operand.hasPredicate())
-        typeConstraints.insert(operand.constraint.getAsOpaquePointer());
-    for (NamedTypeConstraint &result : op.getResults())
-      if (result.hasPredicate())
-        typeConstraints.insert(result.constraint.getAsOpaquePointer());
-  }
-
-  // Record the mapping from predicate to constraint. If two constraints has the
-  // same predicate and constraint summary, they can share the same verification
-  // function.
-  llvm::DenseMap<Pred, const void *> predToConstraint;
-  FmtContext fctx;
-  for (auto it : llvm::enumerate(typeConstraints)) {
-    std::string name;
-    Constraint constraint = Constraint::getFromOpaquePointer(it.value());
-    Pred pred = constraint.getPredicate();
-    auto iter = predToConstraint.find(pred);
-    if (iter != predToConstraint.end()) {
-      do {
-        Constraint built = Constraint::getFromOpaquePointer(iter->second);
-        // We may have the different constraints but have the same predicate,
-        // for example, ConstraintA and Variadic<ConstraintA>, note that
-        // Variadic<> doesn't introduce new predicate. In this case, we can
-        // share the same predicate function if they also have consistent
-        // summary, otherwise we may report the wrong message while verification
-        // fails.
-        if (constraint.getSummary() == built.getSummary()) {
-          name = getTypeConstraintFn(built).str();
-          break;
-        }
-        ++iter;
-      } while (iter != predToConstraint.end() && iter->first == pred);
-    }
-
-    if (!name.empty()) {
-      localTypeConstraints.try_emplace(it.value(), name);
-      continue;
-    }
-
-    // Generate an obscure and unique name for this type constraint.
-    name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
-            Twine(it.index()))
-               .str();
-    predToConstraint.insert(
-        std::make_pair(constraint.getPredicate(), it.value()));
-    localTypeConstraints.try_emplace(it.value(), name);
-
-    // Only generate the methods if we are generating definitions.
-    if (emitDecl)
-      continue;
-
-    os << "static ::mlir::LogicalResult " << name
-       << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
-          "valueKind, unsigned valueGroupStartIndex) {\n";
-
-    os << "  if (!("
-       << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
-       << ")) {\n"
-       << formatv(
-              "    return op->emitOpError(valueKind) << \" #\" << "
-              "valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
-              constraint.getSummary())
-       << "  }\n"
-       << "  return ::mlir::success();\n"
-       << "}\n\n";
-  }
-}
-
-
 //===----------------------------------------------------------------------===//
 // Utility structs and functions
 //===----------------------------------------------------------------------===//
@@ -2560,8 +2398,12 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
     return;
 
   // Generate all of the locally instantiated methods first.
-  StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
-                                                      emitDecl);
+  StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, os);
+  if (!emitDecl) // Declarations emit no local methods, so no banner.
+    os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
+  staticVerifierEmitter.emitFunctionsFor(
+      typeVerifierSignature, typeVerifierErrorHandler, /*typeArgName=*/"type",
+      defs, emitDecl);
   for (auto *def : defs) {
     Operator op(*def);
     if (emitDecl) {


        


More information about the Mlir-commits mailing list