[Mlir-commits] [mlir] 62df4df - [mlir-tblgen] Minor Refactor for StaticVerifierFunctionEmitter.
Chia-hung Duan
llvmlistbot at llvm.org
Thu Aug 12 13:54:00 PDT 2021
Author: Chia-hung Duan
Date: 2021-08-12T20:53:05Z
New Revision: 62df4df41c939205b2dc0a2a3bfb75b8c1ed74fa
URL: https://github.com/llvm/llvm-project/commit/62df4df41c939205b2dc0a2a3bfb75b8c1ed74fa
DIFF: https://github.com/llvm/llvm-project/commit/62df4df41c939205b2dc0a2a3bfb75b8c1ed74fa.diff
LOG: [mlir-tblgen] Minor Refactor for StaticVerifierFunctionEmitter.
Move StaticVerifierFunctionEmitter to CodeGenHelpers.h so that it can be
used for both ODS and DRR.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D106636
Added:
mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
Modified:
mlir/include/mlir/TableGen/CodeGenHelpers.h
mlir/tools/mlir-tblgen/CMakeLists.txt
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index acd80698b3463..68fdf545e23b6 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -13,13 +13,21 @@
#ifndef MLIR_TABLEGEN_CODEGENHELPERS_H
#define MLIR_TABLEGEN_CODEGENHELPERS_H
+#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/Dialect.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
+namespace llvm {
+class RecordKeeper;
+} // namespace llvm
+
namespace mlir {
namespace tblgen {
+class Constraint;
+
// Simple RAII helper for defining ifdef-undef-endif scopes.
class IfDefScope {
public:
@@ -62,6 +70,82 @@ class NamespaceEmitter {
SmallVector<StringRef, 2> namespaces;
};
+/// This class deduplicates shared operation verification code by emitting
+/// static functions alongside the op definitions. These methods are local to
+/// the definition file, and are invoked within the operation verify methods.
+/// Type constraints with the same predicate and summary may share a single
+/// function. An example is shown below:
+///
+///   static LogicalResult localVerify(...)
+///
+///   LogicalResult OpA::verify(...) {
+///     if (failed(localVerify(...)))
+///       return failure();
+///     ...
+///   }
+///
+///   LogicalResult OpB::verify(...) {
+///     if (failed(localVerify(...)))
+///       return failure();
+///     ...
+///   }
+///
+class StaticVerifierFunctionEmitter {
+public:
+ /// Creates an emitter that writes generated code to `os`. A label that keeps
+ /// the generated function names unique is derived from the input file name
+ /// recorded in `records`.
+ StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
+ raw_ostream &os);
+
+ /// Emit the static verifier functions for the type constraints on the
+ /// operands and results of the ops in `opDefs`.
+ ///
+ /// When `emitDecl` is true nothing is printed: the function names are only
+ /// recorded so that getTypeConstraintFn() can return them. When `emitDecl`
+ /// is false the function definitions are printed inside the C++ namespace of
+ /// the first op in `opDefs`.
+ ///
+ /// `signatureFormat` is a `formatv` format string describing the function
+ /// signature; its `{0}` placeholder is replaced with the function name.
+ /// Example,
+ ///   const char *typeVerifierSignature =
+ ///     "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type"
+ ///     " type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
+ ///
+ /// `errorHandlerFormat` is a `formatv` format string for the expression that
+ /// is returned when verification fails. It may contain a `{0}` placeholder,
+ /// which is replaced with the summary of the constraint to make the error
+ /// message more informative.
+ /// Example,
+ ///   const char *typeVerifierErrorHandler =
+ ///     " op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << "
+ ///     "\" must be {0}, but got \" << type";
+ ///
+ /// `typeArgName` names the function argument that holds the type to check;
+ /// `$_self` in the constraint's condition template is replaced with it.
+ void emitFunctionsFor(StringRef signatureFormat, StringRef errorHandlerFormat,
+ StringRef typeArgName, ArrayRef<llvm::Record *> opDefs,
+ bool emitDecl);
+
+ /// Get the name of the local function used for the given type constraint.
+ /// These functions are used for operand and result constraints and have the
+ /// signature passed to emitFunctionsFor(); for ODS that is:
+ ///   LogicalResult(Operation *op, Type type, StringRef valueKind,
+ ///                 unsigned valueGroupStartIndex);
+ /// The constraint must have been registered by a prior emitFunctionsFor()
+ /// call (this is asserted). The returned StringRef points into this
+ /// emitter's map; copy it (e.g. with `.str()`) if it must survive further
+ /// registrations, since map storage may move as it grows.
+ StringRef getTypeConstraintFn(const Constraint &constraint) const;
+
+private:
+ /// Returns a label derived from the base name of the input file (without a
+ /// trailing `.td`; characters other than alphanumerics and `_` are
+ /// hex-encoded), used to make the generated function names unique.
+ static std::string getUniqueName(const llvm::RecordKeeper &records);
+
+ /// Implements emitFunctionsFor(): records a local function name for each
+ /// type constraint used by the operands and results of `opDefs` and, unless
+ /// `emitDecl` is set, prints the function definitions.
+ void emitTypeConstraintMethods(StringRef signatureFormat,
+ StringRef errorHandlerFormat,
+ StringRef typeArgName,
+ ArrayRef<llvm::Record *> opDefs,
+ bool emitDecl);
+
+ /// Stream receiving the generated code; indent()/unindent() control the
+ /// indentation of the emitted function bodies.
+ raw_indented_ostream os;
+
+ /// A unique label for the file currently being generated. This is used to
+ /// ensure that the local functions have a unique name.
+ std::string uniqueOutputLabel;
+
+ /// Maps each type constraint (keyed by Constraint::getAsOpaquePointer()) to
+ /// the name of the local function that verifies it. Several constraints may
+ /// map to the same function name.
+ llvm::DenseMap<const void *, std::string> localTypeConstraints;
+};
+
} // namespace tblgen
} // namespace mlir
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index 083f70e40e6d6..f16e8965daca4 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_LINK_COMPONENTS
add_tablegen(mlir-tblgen MLIR
AttrOrTypeDefGen.cpp
+ CodeGenHelpers.cpp
DialectGen.cpp
DirectiveCommonGen.cpp
EnumsGen.cpp
diff --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
new file mode 100644
index 0000000000000..c003ad9eb673c
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
@@ -0,0 +1,139 @@
+//===- CodeGenHelpers.cpp - MLIR TableGen code generation helpers ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the TableGen code generation helpers declared in
+// mlir/TableGen/CodeGenHelpers.h, such as StaticVerifierFunctionEmitter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Path.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::tblgen;
+
+// The label is computed once, from the input file name, so that every function
+// emitted by this instance carries the same file-specific prefix.
+StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
+    const llvm::RecordKeeper &records, raw_ostream &os)
+    : os(os),
+      uniqueOutputLabel(getUniqueName(records)) {}
+
+// Emits (or, when `emitDecl` is true, only registers the names of) the static
+// type-constraint verifiers used by `opDefs`. An empty `opDefs` is a no-op.
+void StaticVerifierFunctionEmitter::emitFunctionsFor(
+    StringRef signatureFormat, StringRef errorHandlerFormat,
+    StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
+  // The namespace below is taken from the first record, so one must exist;
+  // with no records there are no constraints to emit anyway.
+  if (opDefs.empty())
+    return;
+
+  // Definitions are wrapped in the ops' C++ namespace. Declaration mode prints
+  // nothing, so no namespace is opened for it.
+  llvm::Optional<NamespaceEmitter> namespaceEmitter;
+  if (!emitDecl)
+    namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
+
+  emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName,
+                            opDefs, emitDecl);
+}
+
+// Looks up the verifier name recorded for `constraint`; the constraint must
+// already have been processed by emitFunctionsFor().
+StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
+    const Constraint &constraint) const {
+  const void *key = constraint.getAsOpaquePointer();
+  auto entry = localTypeConstraints.find(key);
+  assert(entry != localTypeConstraints.end() &&
+         "expected valid constraint fn");
+  return entry->second;
+}
+
+// Builds the per-file label used in generated function names: the input file's
+// base name without a trailing ".td", with every character that is neither
+// alphanumeric nor '_' replaced by its hexadecimal code.
+std::string StaticVerifierFunctionEmitter::getUniqueName(
+    const llvm::RecordKeeper &records) {
+  // Keep the full path alive: `filename` returns a view into it.
+  std::string path = records.getInputFilename();
+  StringRef base = llvm::sys::path::filename(path);
+  base.consume_back(".td");
+
+  std::string label;
+  for (unsigned char ch : base) {
+    bool isIdentChar = llvm::isAlnum(ch) || ch == '_';
+    if (!isIdentChar) {
+      label += llvm::utohexstr(ch);
+      continue;
+    }
+    label += static_cast<char>(ch);
+  }
+  return label;
+}
+
+void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
+    StringRef signatureFormat, StringRef errorHandlerFormat,
+    StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
+  // Collect a set of all of the used type constraints within the operation
+  // definitions, in first-use order.
+  llvm::SetVector<const void *> typeConstraints;
+  for (Record *def : opDefs) {
+    Operator op(*def);
+    for (NamedTypeConstraint &operand : op.getOperands())
+      if (operand.hasPredicate())
+        typeConstraints.insert(operand.constraint.getAsOpaquePointer());
+    for (NamedTypeConstraint &result : op.getResults())
+      if (result.hasPredicate())
+        typeConstraints.insert(result.constraint.getAsOpaquePointer());
+  }
+
+  // Different constraints may have the same predicate, e.g. ConstraintA and
+  // Variadic<ConstraintA> (Variadic<> does not introduce a new predicate).
+  // Such constraints can share one verification function only if their
+  // summaries also match; otherwise a failing check would report the wrong
+  // message. Several constraints with the same predicate but different
+  // summaries may each own a function, so every predicate maps to all of them
+  // (a DenseMap<Pred, const void *> has unique keys and would keep only the
+  // first one).
+  llvm::DenseMap<Pred, llvm::SmallVector<const void *, 2>> predToConstraints;
+
+  FmtContext fctx;
+  fctx.withSelf(typeArgName);
+
+  // `formatv` expects a NUL-terminated format string, which a StringRef does
+  // not guarantee; materialize both formats once, outside the loop.
+  std::string signatureFmt = signatureFormat.str();
+  std::string errorHandlerFmt = errorHandlerFormat.str();
+
+  for (auto it : llvm::enumerate(typeConstraints)) {
+    const void *opaque = it.value();
+    Constraint constraint = Constraint::getFromOpaquePointer(opaque);
+    llvm::SmallVector<const void *, 2> &samePred =
+        predToConstraints[constraint.getPredicate()];
+
+    // Reuse the function of an earlier constraint with a matching summary.
+    auto match = llvm::find_if(samePred, [&](const void *other) {
+      return Constraint::getFromOpaquePointer(other).getSummary() ==
+             constraint.getSummary();
+    });
+    if (match != samePred.end()) {
+      // Copy the name before inserting: growing `localTypeConstraints` may
+      // move the string that getTypeConstraintFn's StringRef points into.
+      std::string name =
+          getTypeConstraintFn(Constraint::getFromOpaquePointer(*match)).str();
+      localTypeConstraints.try_emplace(opaque, std::move(name));
+      continue;
+    }
+
+    // Generate an obscure and unique name for this type constraint.
+    std::string name = (Twine("__mlir_ods_local_type_constraint_") +
+                        uniqueOutputLabel + Twine(it.index()))
+                           .str();
+    samePred.push_back(opaque);
+    localTypeConstraints.try_emplace(opaque, name);
+
+    // Only generate the methods if we are generating definitions.
+    if (emitDecl)
+      continue;
+
+    os << formatv(signatureFmt.c_str(), name) << " {\n";
+    os.indent() << "if (!("
+                << tgfmt(constraint.getConditionTemplate(), &fctx)
+                << ")) {\n";
+    os.indent() << "return "
+                << formatv(errorHandlerFmt.c_str(), constraint.getSummary())
+                << ";\n";
+    os.unindent() << "}\nreturn ::mlir::success();\n";
+    os.unindent() << "}\n\n";
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 2a851618938b0..269803f50788f 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -24,7 +24,6 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/Path.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -101,6 +100,14 @@ const char *valueRangeReturnCode = R"(
std::next({0}, valueRange.first + valueRange.second)};
)";
+/// Signature of the static type-constraint verifiers emitted for ops; `{0}` is
+/// replaced with the verifier's name. File-local (`static`) and immutable so
+/// it cannot collide with same-named globals in other mlir-tblgen sources.
+static const char *const typeVerifierSignature =
+    "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type "
+    "type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
+
+/// Expression returned by a type-constraint verifier when the check fails;
+/// `{0}` is replaced with the constraint summary. The emitter prints "return "
+/// in front of it, so no leading space is needed.
+static const char *const typeVerifierErrorHandler =
+    "op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << \" must "
+    "be {0}, but got \" << type";
+
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
@@ -108,175 +115,6 @@ static const char *const opCommentHeader = R"(
)";
-//===----------------------------------------------------------------------===//
-// StaticVerifierFunctionEmitter
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This class deduplicates shared operation verification code by emitting
-/// static functions alongside the op definitions. These methods are local to
-/// the definition file, and are invoked within the operation verify methods.
-/// An example is shown below:
-///
-/// static LogicalResult localVerify(...)
-///
-/// LogicalResult OpA::verify(...) {
-/// if (failed(localVerify(...)))
-/// return failure();
-/// ...
-/// }
-///
-/// LogicalResult OpB::verify(...) {
-/// if (failed(localVerify(...)))
-/// return failure();
-/// ...
-/// }
-///
-class StaticVerifierFunctionEmitter {
-public:
- StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
- ArrayRef<llvm::Record *> opDefs,
- raw_ostream &os, bool emitDecl);
-
- /// Get the name of the local function used for the given type constraint.
- /// These functions are used for operand and result constraints and have the
- /// form:
- /// LogicalResult(Operation *op, Type type, StringRef valueKind,
- /// unsigned valueGroupStartIndex);
- StringRef getTypeConstraintFn(const Constraint &constraint) const {
- auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
- assert(it != localTypeConstraints.end() && "expected valid constraint fn");
- return it->second;
- }
-
-private:
- /// Returns a unique name to use when generating local methods.
- static std::string getUniqueName(const llvm::RecordKeeper &records);
-
- /// Emit local methods for the type constraints used within the provided op
- /// definitions.
- void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
- raw_ostream &os, bool emitDecl);
-
- /// A unique label for the file currently being generated. This is used to
- /// ensure that the local functions have a unique name.
- std::string uniqueOutputLabel;
-
- /// A set of functions implementing type constraints, used for operand and
- /// result verification.
- llvm::DenseMap<const void *, std::string> localTypeConstraints;
-};
-} // namespace
-
-StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
- const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
- raw_ostream &os, bool emitDecl)
- : uniqueOutputLabel(getUniqueName(records)) {
- llvm::Optional<NamespaceEmitter> namespaceEmitter;
- if (!emitDecl) {
- os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
- namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
- }
-
- emitTypeConstraintMethods(opDefs, os, emitDecl);
-}
-
-std::string StaticVerifierFunctionEmitter::getUniqueName(
- const llvm::RecordKeeper &records) {
- // Use the input file name when generating a unique name.
- std::string inputFilename = records.getInputFilename();
-
- // Drop all but the base filename.
- StringRef nameRef = llvm::sys::path::filename(inputFilename);
- nameRef.consume_back(".td");
-
- // Sanitize any invalid characters.
- std::string uniqueName;
- for (char c : nameRef) {
- if (llvm::isAlnum(c) || c == '_')
- uniqueName.push_back(c);
- else
- uniqueName.append(llvm::utohexstr((unsigned char)c));
- }
- return uniqueName;
-}
-
-void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
- ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
- // Collect a set of all of the used type constraints within the operation
- // definitions.
- llvm::SetVector<const void *> typeConstraints;
- for (Record *def : opDefs) {
- Operator op(*def);
- for (NamedTypeConstraint &operand : op.getOperands())
- if (operand.hasPredicate())
- typeConstraints.insert(operand.constraint.getAsOpaquePointer());
- for (NamedTypeConstraint &result : op.getResults())
- if (result.hasPredicate())
- typeConstraints.insert(result.constraint.getAsOpaquePointer());
- }
-
- // Record the mapping from predicate to constraint. If two constraints has the
- // same predicate and constraint summary, they can share the same verification
- // function.
- llvm::DenseMap<Pred, const void *> predToConstraint;
- FmtContext fctx;
- for (auto it : llvm::enumerate(typeConstraints)) {
- std::string name;
- Constraint constraint = Constraint::getFromOpaquePointer(it.value());
- Pred pred = constraint.getPredicate();
- auto iter = predToConstraint.find(pred);
- if (iter != predToConstraint.end()) {
- do {
- Constraint built = Constraint::getFromOpaquePointer(iter->second);
- // We may have the different constraints but have the same predicate,
- // for example, ConstraintA and Variadic<ConstraintA>, note that
- // Variadic<> doesn't introduce new predicate. In this case, we can
- // share the same predicate function if they also have consistent
- // summary, otherwise we may report the wrong message while verification
- // fails.
- if (constraint.getSummary() == built.getSummary()) {
- name = getTypeConstraintFn(built).str();
- break;
- }
- ++iter;
- } while (iter != predToConstraint.end() && iter->first == pred);
- }
-
- if (!name.empty()) {
- localTypeConstraints.try_emplace(it.value(), name);
- continue;
- }
-
- // Generate an obscure and unique name for this type constraint.
- name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
- Twine(it.index()))
- .str();
- predToConstraint.insert(
- std::make_pair(constraint.getPredicate(), it.value()));
- localTypeConstraints.try_emplace(it.value(), name);
-
- // Only generate the methods if we are generating definitions.
- if (emitDecl)
- continue;
-
- os << "static ::mlir::LogicalResult " << name
- << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
- "valueKind, unsigned valueGroupStartIndex) {\n";
-
- os << " if (!("
- << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
- << ")) {\n"
- << formatv(
- " return op->emitOpError(valueKind) << \" #\" << "
- "valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
- constraint.getSummary())
- << " }\n"
- << " return ::mlir::success();\n"
- << "}\n\n";
- }
-}
-
//===----------------------------------------------------------------------===//
// Utility structs and functions
//===----------------------------------------------------------------------===//
@@ -2560,8 +2398,12 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
return;
// Generate all of the locally instantiated methods first.
- StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
- emitDecl);
+ StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, os);
+ os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
+ staticVerifierEmitter.emitFunctionsFor(
+ typeVerifierSignature, typeVerifierErrorHandler, /*typeArgName=*/"type",
+ defs, emitDecl);
+
for (auto *def : defs) {
Operator op(*def);
if (emitDecl) {
More information about the Mlir-commits
mailing list