[llvm-branch-commits] [mlir] b3ee7f1 - [mlir][OpDefGen] Add support for generating local functions for shared utilities

River Riddle via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Dec 14 14:26:40 PST 2020


Author: River Riddle
Date: 2020-12-14T14:21:30-08:00
New Revision: b3ee7f1f312dd41a0ac883d78b71c14d96e78939

URL: https://github.com/llvm/llvm-project/commit/b3ee7f1f312dd41a0ac883d78b71c14d96e78939
DIFF: https://github.com/llvm/llvm-project/commit/b3ee7f1f312dd41a0ac883d78b71c14d96e78939.diff

LOG: [mlir][OpDefGen] Add support for generating local functions for shared utilities

This revision adds a new `StaticVerifierFunctionEmitter` class that emits local static functions in the .cpp file for shared operation verification. This class deduplicates shared operation verification code by emitting static functions alongside the op definitions. These functions are local to the definition file, and are invoked within the operation verify methods. The first piece of verification shared this way is the checking of the type constraints used when verifying operands and results. An example is shown below:

```
static LogicalResult localVerify(...) {
  ...
}

LogicalResult OpA::verify(...) {
  if (failed(localVerify(...)))
    return failure();
  ...
}

LogicalResult OpB::verify(...) {
  if (failed(localVerify(...)))
    return failure();
  ...
}
```

This saved >400KB of code size in a downstream TensorFlow project (~15% of MLIR code size).

Differential Revision: https://reviews.llvm.org/D91381

Added: 
    

Modified: 
    mlir/include/mlir/TableGen/Constraint.h
    mlir/test/mlir-tblgen/predicate.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h
index 775b3545a034..e5ced0361c16 100644
--- a/mlir/include/mlir/TableGen/Constraint.h
+++ b/mlir/include/mlir/TableGen/Constraint.h
@@ -52,6 +52,13 @@ class Constraint {
 
   Kind getKind() const { return kind; }
 
+  /// Get an opaque pointer to the constraint.
+  const void *getAsOpaquePointer() const { return def; }
+  /// Construct a constraint from the opaque pointer representation.
+  static Constraint getFromOpaquePointer(const void *ptr) {
+    return Constraint(reinterpret_cast<const llvm::Record *>(ptr));
+  }
+
 protected:
   Constraint(Kind kind, const llvm::Record *record);
 

diff  --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index e5d4ac853b50..386c61319b79 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -15,10 +15,22 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
   let arguments = (ins I32OrF32:$x);
 }
 
+// CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
+// CHECK:  if (!((type.isInteger(32) || type.isF32()))) {
+// CHECK:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
+
+// CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
+// CHECK:  if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
+// CHECK:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
+
+// CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
+// CHECK:  if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
+// CHECK:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
+
 // CHECK-LABEL: OpA::verify
 // CHECK: auto valueGroup0 = getODSOperands(0);
 // CHECK: for (::mlir::Value v : valueGroup0) {
-// CHECK:   if (!((v.getType().isInteger(32) || v.getType().isF32())))
+// CHECK:   if (::mlir::failed([[$INTEGER_FLOAT_CONSTRAINT]]
 
 def OpB : NS_Op<"op_for_And_PredOpTrait", [
     PredOpTrait<"both first and second holds",
@@ -93,4 +105,4 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
 // CHECK-LABEL: OpK::verify
 // CHECK: auto valueGroup0 = getODSOperands(0);
 // CHECK: for (::mlir::Value v : valueGroup0) {
-// CHECK: if (!(((v.getType().isa<::mlir::TensorType>())) && (((v.getType().cast<::mlir::ShapedType>().getElementType().isF32())) || ((v.getType().cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32))))))
+// CHECK: if (::mlir::failed([[$TENSOR_INTEGER_FLOAT_CONSTRAINT]]

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index f2a57fbbb00f..9a4ddbffac43 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -117,6 +117,144 @@ static const char *const opCommentHeader = R"(
 
 )";
 
+//===----------------------------------------------------------------------===//
+// StaticVerifierFunctionEmitter
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class deduplicates shared operation verification code by emitting
+/// static functions alongside the op definitions. These methods are local to
+/// the definition file, and are invoked within the operation verify methods.
+/// An example is shown below:
+///
+/// static LogicalResult localVerify(...)
+///
+/// LogicalResult OpA::verify(...) {
+///  if (failed(localVerify(...)))
+///    return failure();
+///  ...
+/// }
+///
+/// LogicalResult OpB::verify(...) {
+///  if (failed(localVerify(...)))
+///    return failure();
+///  ...
+/// }
+///
+class StaticVerifierFunctionEmitter {
+public:
+  StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
+                                ArrayRef<llvm::Record *> opDefs,
+                                raw_ostream &os, bool emitDecl);
+
+  /// Get the name of the local function used for the given type constraint.
+  /// These functions are used for operand and result constraints and have the
+  /// form:
+  ///   LogicalResult(Operation *op, Type type, StringRef valueKind,
+  ///                 unsigned valueGroupStartIndex);
+  StringRef getTypeConstraintFn(const Constraint &constraint) const {
+    auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
+    assert(it != localTypeConstraints.end() && "expected valid constraint fn");
+    return it->second;
+  }
+
+private:
+  /// Returns a unique name to use when generating local methods.
+  static std::string getUniqueName(const llvm::RecordKeeper &records);
+
+  /// Emit local methods for the type constraints used within the provided op
+  /// definitions.
+  void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
+                                 raw_ostream &os, bool emitDecl);
+
+  /// A unique label for the file currently being generated. This is used to
+  /// ensure that the local functions have a unique name.
+  std::string uniqueOutputLabel;
+
+  /// A set of functions implementing type constraints, used for operand and
+  /// result verification.
+  llvm::DenseMap<const void *, std::string> localTypeConstraints;
+};
+} // namespace
+
+StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
+    const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
+    raw_ostream &os, bool emitDecl)
+    : uniqueOutputLabel(getUniqueName(records)) {
+  llvm::Optional<NamespaceEmitter> namespaceEmitter;
+  if (!emitDecl) {
+    os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
+    namespaceEmitter.emplace(os, Operator(*opDefs[0]).getDialect());
+  }
+
+  emitTypeConstraintMethods(opDefs, os, emitDecl);
+}
+
+std::string StaticVerifierFunctionEmitter::getUniqueName(
+    const llvm::RecordKeeper &records) {
+  // Use the input file name when generating a unique name.
+  std::string inputFilename = records.getInputFilename();
+
+  // Drop all but the base filename.
+  StringRef nameRef = llvm::sys::path::filename(inputFilename);
+  nameRef.consume_back(".td");
+
+  // Sanitize any invalid characters.
+  std::string uniqueName;
+  for (char c : nameRef) {
+    if (llvm::isAlnum(c) || c == '_')
+      uniqueName.push_back(c);
+    else
+      uniqueName.append(llvm::utohexstr((unsigned char)c));
+  }
+  return uniqueName;
+}
+
+void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
+    ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
+  // Collect a set of all of the used type constraints within the operation
+  // definitions.
+  llvm::SetVector<const void *> typeConstraints;
+  for (Record *def : opDefs) {
+    Operator op(*def);
+    for (NamedTypeConstraint &operand : op.getOperands())
+      if (operand.hasPredicate())
+        typeConstraints.insert(operand.constraint.getAsOpaquePointer());
+    for (NamedTypeConstraint &result : op.getResults())
+      if (result.hasPredicate())
+        typeConstraints.insert(result.constraint.getAsOpaquePointer());
+  }
+
+  FmtContext fctx;
+  for (auto it : llvm::enumerate(typeConstraints)) {
+    // Generate an obscure and unique name for this type constraint.
+    std::string name = (Twine("__mlir_ods_local_type_constraint_") +
+                        uniqueOutputLabel + Twine(it.index()))
+                           .str();
+    localTypeConstraints.try_emplace(it.value(), name);
+
+    // Only generate the methods if we are generating definitions.
+    if (emitDecl)
+      continue;
+
+    Constraint constraint = Constraint::getFromOpaquePointer(it.value());
+    os << "static ::mlir::LogicalResult " << name
+       << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
+          "valueKind, unsigned valueGroupStartIndex) {\n";
+
+    os << "  if (!("
+       << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
+       << ")) {\n"
+       << formatv(
+              "    return op->emitOpError(valueKind) << \" #\" << "
+              "valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
+              constraint.getDescription())
+       << "  }\n"
+       << "  return ::mlir::success();\n"
+       << "}\n\n";
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Utility structs and functions
 //===----------------------------------------------------------------------===//
@@ -164,11 +302,16 @@ namespace {
 // Helper class to emit a record into the given output stream.
 class OpEmitter {
 public:
-  static void emitDecl(const Operator &op, raw_ostream &os);
-  static void emitDef(const Operator &op, raw_ostream &os);
+  static void
+  emitDecl(const Operator &op, raw_ostream &os,
+           const StaticVerifierFunctionEmitter &staticVerifierEmitter);
+  static void
+  emitDef(const Operator &op, raw_ostream &os,
+          const StaticVerifierFunctionEmitter &staticVerifierEmitter);
 
 private:
-  OpEmitter(const Operator &op);
+  OpEmitter(const Operator &op,
+            const StaticVerifierFunctionEmitter &staticVerifierEmitter);
 
   void emitDecl(raw_ostream &os);
   void emitDef(raw_ostream &os);
@@ -321,6 +464,9 @@ class OpEmitter {
 
   // The format context for verification code generation.
   FmtContext verifyCtx;
+
+  // The emitter containing all of the locally emitted verification functions.
+  const StaticVerifierFunctionEmitter &staticVerifierEmitter;
 };
 } // end anonymous namespace
 
@@ -434,9 +580,11 @@ static void genAttributeVerifier(const Operator &op, const char *attrGet,
   }
 }
 
-OpEmitter::OpEmitter(const Operator &op)
+OpEmitter::OpEmitter(const Operator &op,
+                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
     : def(op.getDef()), op(op),
-      opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
+      opClass(op.getCppClassName(), op.getExtraClassDeclaration()),
+      staticVerifierEmitter(staticVerifierEmitter) {
   verifyCtx.withOp("(*this->getOperation())");
 
   genTraits();
@@ -464,12 +612,16 @@ OpEmitter::OpEmitter(const Operator &op)
   genSideEffectInterfaceMethods();
 }
 
-void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
-  OpEmitter(op).emitDecl(os);
+void OpEmitter::emitDecl(
+    const Operator &op, raw_ostream &os,
+    const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
+  OpEmitter(op, staticVerifierEmitter).emitDecl(os);
 }
 
-void OpEmitter::emitDef(const Operator &op, raw_ostream &os) {
-  OpEmitter(op).emitDef(os);
+void OpEmitter::emitDef(
+    const Operator &op, raw_ostream &os,
+    const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
+  OpEmitter(op, staticVerifierEmitter).emitDef(os);
 }
 
 void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
@@ -1891,23 +2043,16 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
     // Otherwise, if there is no predicate there is nothing left to do.
     if (!hasPredicate)
       continue;
-
     // Emit a loop to check all the dynamic values in the pack.
+    StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(
+        staticValue.value().constraint);
     body << "    for (::mlir::Value v : valueGroup" << staticValue.index()
-         << ") {\n";
-
-    auto constraint = staticValue.value().constraint;
-    body << "      (void)v;\n"
-         << "      if (!("
-         << tgfmt(constraint.getConditionTemplate(),
-                  &fctx.withSelf("v.getType()"))
-         << ")) {\n"
-         << formatv("        return emitOpError(\"{0} #\") << index "
-                    "<< \" must be {1}, but got \" << v.getType();\n",
-                    valueKind, constraint.getDescription())
-         << "      }\n" // if
+         << ") {\n"
+         << "      if (::mlir::failed(" << constraintFn
+         << "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n"
+         << "        return ::mlir::failure();\n"
          << "      ++index;\n"
-         << "    }\n"; // for
+         << "    }\n";
   }
 
   body << "  }\n";
@@ -2248,7 +2393,8 @@ void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
 }
 
 // Emits the opcode enum and op classes.
-static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
+static void emitOpClasses(const RecordKeeper &recordKeeper,
+                          const std::vector<Record *> &defs, raw_ostream &os,
                           bool emitDecl) {
   // First emit forward declaration for each class, this allows them to refer
   // to each others in traits for example.
@@ -2264,17 +2410,23 @@ static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
   }
 
   IfDefScope scope("GET_OP_CLASSES", os);
+  if (defs.empty())
+    return;
+
+  // Generate all of the locally instantiated methods first.
+  StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
+                                                      emitDecl);
   for (auto *def : defs) {
     Operator op(*def);
     NamespaceEmitter emitter(os, op.getDialect());
     if (emitDecl) {
       os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
       OpOperandAdaptorEmitter::emitDecl(op, os);
-      OpEmitter::emitDecl(op, os);
+      OpEmitter::emitDecl(op, os, staticVerifierEmitter);
     } else {
       os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
       OpOperandAdaptorEmitter::emitDef(op, os);
-      OpEmitter::emitDef(op, os);
+      OpEmitter::emitDef(op, os, staticVerifierEmitter);
     }
   }
 }
@@ -2329,7 +2481,7 @@ static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Declarations", os);
 
   const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
-  emitOpClasses(defs, os, /*emitDecl=*/true);
+  emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
 
   return false;
 }
@@ -2339,7 +2491,7 @@ static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
 
   const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
   emitOpList(defs, os);
-  emitOpClasses(defs, os, /*emitDecl=*/false);
+  emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
 
   return false;
 }


        


More information about the llvm-branch-commits mailing list