[llvm-branch-commits] [mlir] 2074177 - [mlir][ODS] Add a C++ abstraction for OpBuilders

River Riddle via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Jan 11 12:13:23 PST 2021


Author: River Riddle
Date: 2021-01-11T12:06:22-08:00
New Revision: 207417730134931c7d5bf82e0b16c7757ad05e05

URL: https://github.com/llvm/llvm-project/commit/207417730134931c7d5bf82e0b16c7757ad05e05
DIFF: https://github.com/llvm/llvm-project/commit/207417730134931c7d5bf82e0b16c7757ad05e05.diff

LOG: [mlir][ODS] Add a C++ abstraction for OpBuilders

This removes the need for OpDefinitionsGen to use the raw TableGen API, and will
also simplify adding builders to TypeDefs.

Differential Revision: https://reviews.llvm.org/D94273

Added: 
    mlir/include/mlir/TableGen/Builder.h
    mlir/lib/TableGen/Builder.cpp

Modified: 
    mlir/include/mlir/TableGen/Operator.h
    mlir/lib/TableGen/CMakeLists.txt
    mlir/lib/TableGen/Operator.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/Builder.h b/mlir/include/mlir/TableGen/Builder.h
new file mode 100644
index 000000000000..b901c8414e81
--- /dev/null
+++ b/mlir/include/mlir/TableGen/Builder.h
@@ -0,0 +1,85 @@
+//===- Builder.h - Builder classes ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Builder wrapper to simplify using TableGen Record for building
+// operations/types/etc.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_BUILDER_H_
+#define MLIR_TABLEGEN_BUILDER_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class Init;
+class Record;
+class SMLoc;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+/// Wrapper around a TableGen builder record (e.g. an entry of an op's
+/// `builders` list) that exposes its parameters and optional C++ body.
+class Builder {
+public:
+  /// A single parameter of a builder, taken from one `ins` DAG argument.
+  class Parameter {
+  public:
+    /// Return the C++ type: the plain type string or the CArg `type` field.
+    StringRef getCppType() const;
+
+    /// Return an optional string containing the name of this parameter. If
+    /// None, no name was specified for this parameter by the user.
+    Optional<StringRef> getName() const { return name; }
+
+    /// Return the default value of this parameter, or None if it is a plain
+    /// type string or a CArg whose default value is unset or empty.
+    Optional<StringRef> getDefaultValue() const;
+
+  private:
+    Parameter(Optional<StringRef> name, const llvm::Init *def)
+        : name(name), def(def) {}
+
+    /// The optional name of the parameter.
+    Optional<StringRef> name;
+
+    /// The tablegen definition of the parameter. This is either a StringInit,
+    /// or a CArg DefInit.
+    const llvm::Init *def;
+
+    // Only Builder may construct parameters (the constructor is private).
+    friend Builder;
+  };
+
+  /// Construct a builder from `record`; `loc` is used for parameter errors.
+  Builder(const llvm::Record *record, ArrayRef<llvm::SMLoc> loc);
+
+  /// Return the parameters of this builder, in declaration order.
+  ArrayRef<Parameter> getParameters() const { return parameters; }
+
+  /// Return the C++ body of the builder, or None if it is unset or empty.
+  Optional<StringRef> getBody() const;
+
+protected:
+  /// The TableGen definition of this builder.
+  const llvm::Record *def;
+
+private:
+  /// The parameters of the builder, in the order they were declared.
+  SmallVector<Parameter> parameters;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_BUILDER_H_

diff  --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 16d154b3beb0..d21b4b213ee4 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -16,6 +16,7 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Argument.h"
 #include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Builder.h"
 #include "mlir/TableGen/Dialect.h"
 #include "mlir/TableGen/OpTrait.h"
 #include "mlir/TableGen/Region.h"
@@ -287,6 +288,9 @@ class Operator {
   // Returns the OperandOrAttribute corresponding to the index.
   OperandOrAttribute getArgToOperandOrAttribute(int index) const;
 
+  // Returns the builders of this operation.
+  ArrayRef<Builder> getBuilders() const { return builders; }
+
 private:
   // Populates the vectors containing operands, attributes, results and traits.
   void populateOpStructure();
@@ -332,6 +336,9 @@ class Operator {
   // Map from argument to attribute or operand number.
   SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
 
+  // The builders of this operator.
+  SmallVector<Builder> builders;
+
   // The number of native attributes stored in the leading positions of
   // `attributes`.
   int numNativeAttributes;

diff  --git a/mlir/lib/TableGen/Builder.cpp b/mlir/lib/TableGen/Builder.cpp
new file mode 100644
index 000000000000..8210e8fe1a12
--- /dev/null
+++ b/mlir/lib/TableGen/Builder.cpp
@@ -0,0 +1,85 @@
+//===- Builder.cpp - Builder definitions ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Builder.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// Builder::Parameter
+//===----------------------------------------------------------------------===//
+
+/// Return the C++ type of this parameter: either the plain type string, or
+/// the `type` field of the CArg def the parameter was declared with.
+StringRef Builder::Parameter::getCppType() const {
+  if (const auto *stringInit = dyn_cast<llvm::StringInit>(def))
+    return stringInit->getValue();
+  const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
+  return record->getValueAsString("type");
+}
+
+/// Return the default value of this parameter, or None if the parameter is a
+/// plain type string or its CArg `defaultValue` is unset or empty.
+Optional<StringRef> Builder::Parameter::getDefaultValue() const {
+  if (isa<llvm::StringInit>(def))
+    return llvm::None;
+  const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
+  Optional<StringRef> value = record->getValueAsOptionalString("defaultValue");
+  return value && !value->empty() ? value : llvm::None;
+}
+
+//===----------------------------------------------------------------------===//
+// Builder
+//===----------------------------------------------------------------------===//
+
+/// Construct the wrapper from `record`, whose `dagParams` field must be an
+/// `ins` DAG. Each DAG argument becomes one Parameter, in declaration order.
+Builder::Builder(const llvm::Record *record, ArrayRef<llvm::SMLoc> loc)
+    : def(record) {
+  // Initialize the parameters of the builder.
+  const llvm::DagInit *dag = def->getValueAsDag("dagParams");
+  auto *defInit = dyn_cast<llvm::DefInit>(dag->getOperator());
+  if (!defInit || !defInit->getDef()->getName().equals("ins"))
+    PrintFatalError(def->getLoc(), "expected 'ins' in builders");
+
+  bool seenDefaultValue = false;
+  for (unsigned i = 0, e = dag->getNumArgs(); i < e; ++i) {
+    const llvm::StringInit *paramName = dag->getArgName(i);
+    const llvm::Init *paramValue = dag->getArg(i);
+
+    // Parameter's accessors assume each argument is either a plain C++ type
+    // string or a CArg def. Reject anything else (e.g. a bare `$name` with no
+    // type) with a diagnostic instead of tripping a `cast` later on.
+    if (!isa<llvm::StringInit>(paramValue) && !isa<llvm::DefInit>(paramValue))
+      PrintFatalError(loc, "expected builder parameter to be either a C++ "
+                           "type string or a CArg");
+
+    Parameter param(paramName ? paramName->getValue() : Optional<StringRef>(),
+                    paramValue);
+
+    // Similarly to C++, once an argument with a default value is detected, the
+    // following arguments must have default values as well.
+    if (param.getDefaultValue()) {
+      seenDefaultValue = true;
+    } else if (seenDefaultValue) {
+      PrintFatalError(loc,
+                      "expected an argument with default value after other "
+                      "arguments with default values");
+    }
+    parameters.emplace_back(param);
+  }
+}
+
+/// Return the C++ body of the builder, or None if it is unset or empty.
+Optional<StringRef> Builder::getBody() const {
+  Optional<StringRef> body = def->getValueAsOptionalString("body");
+  return body && !body->empty() ? body : llvm::None;
+}

diff  --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index a11473b9370c..fa52dde27a40 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -11,6 +11,7 @@
 llvm_add_library(MLIRTableGen STATIC
   Argument.cpp
   Attribute.cpp
+  Builder.cpp
   Constraint.cpp
   Dialect.cpp
   Format.cpp

diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index ca1442f547d5..209c1ec0d94a 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -521,6 +521,18 @@ void Operator::populateOpStructure() {
     regions.push_back({name, region});
   }
 
+  // Populate the builders.
+  auto *builderList =
+      dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders"));
+  if (builderList && !builderList->empty()) {
+    for (llvm::Init *init : builderList->getValues())
+      builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc());
+  } else if (skipDefaultBuilders()) {
+    PrintFatalError(
+        def.getLoc(),
+        "default builders are skipped and no custom builders provided");
+  }
+
   LLVM_DEBUG(print(llvm::dbgs()));
 }
 

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index ed04bd07fc4c..69c6b022054e 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -48,7 +48,7 @@ static cl::opt<std::string> opExcFilter(
 
 static const char *const tblgenNamePrefix = "tblgen_";
 static const char *const generatedArgName = "odsArg";
-static const char *const builder = "odsBuilder";
+static const char *const odsBuilder = "odsBuilder";
 static const char *const builderOpState = "odsState";
 
 // The logic to calculate the actual value range for a declared operand/result
@@ -1326,54 +1326,31 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
   body << "  }\n";
 }
 
-/// Returns a signature of the builder as defined by a dag-typed initializer.
-/// Updates the context `fctx` to enable replacement of $_builder and $_state
-/// in the body. Reports errors at `loc`.
-static std::string builderSignatureFromDAG(const DagInit *init,
-                                           ArrayRef<llvm::SMLoc> loc) {
-  auto *defInit = dyn_cast<DefInit>(init->getOperator());
-  if (!defInit || !defInit->getDef()->getName().equals("ins"))
-    PrintFatalError(loc, "expected 'ins' in builders");
+/// Returns the C++ parameter list for `builder`, starting with the injected
+/// OpBuilder and OperationState arguments; unnamed params get `odsArg<i>`.
+static std::string getBuilderSignature(const Builder &builder) {
+  ArrayRef<Builder::Parameter> params(builder.getParameters());
 
   // Inject builder and state arguments.
   llvm::SmallVector<std::string, 8> arguments;
-  arguments.reserve(init->getNumArgs() + 2);
-  arguments.push_back(llvm::formatv("::mlir::OpBuilder &{0}", builder).str());
+  arguments.reserve(params.size() + 2);
+  arguments.push_back(
+      llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str());
   arguments.push_back(
       llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());
 
-  // Accept either a StringInit or a DefInit with two string values as dag
-  // arguments. The former corresponds to the type, the latter to the type and
-  // the default value. Similarly to C++, once an argument with a default value
-  // is detected, the following arguments must have default values as well.
-  bool seenDefaultValue = false;
-  for (unsigned i = 0, e = init->getNumArgs(); i < e; ++i) {
+  for (unsigned i = 0, e = params.size(); i < e; ++i) {
     // If no name is provided, generate one.
-    StringInit *argName = init->getArgName(i);
+    Optional<StringRef> paramName = params[i].getName();
     std::string name =
-        argName ? argName->getValue().str() : "odsArg" + std::to_string(i);
+        paramName ? paramName->str() : "odsArg" + std::to_string(i);
 
-    Init *argInit = init->getArg(i);
-    StringRef type;
     std::string defaultValue;
-    if (StringInit *strType = dyn_cast<StringInit>(argInit)) {
-      type = strType->getValue();
-    } else {
-      const Record *typeAndDefaultValue = cast<DefInit>(argInit)->getDef();
-      type = typeAndDefaultValue->getValueAsString("type");
-      StringRef defaultValueRef =
-          typeAndDefaultValue->getValueAsString("defaultValue");
-      if (!defaultValueRef.empty()) {
-        seenDefaultValue = true;
-        defaultValue = llvm::formatv(" = {0}", defaultValueRef).str();
-      }
-    }
-    if (seenDefaultValue && defaultValue.empty())
-      PrintFatalError(loc,
-                      "expected an argument with default value after other "
-                      "arguments with default values");
+    if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
+      defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str();
     arguments.push_back(
-        llvm::formatv("{0} {1}{2}", type, name, defaultValue).str());
+        llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue)
+            .str());
   }
 
   return llvm::join(arguments, ", ");
@@ -1381,41 +1358,26 @@ static std::string builderSignatureFromDAG(const DagInit *init,
 
 void OpEmitter::genBuilder() {
   // Handle custom builders if provided.
-  // TODO: Create wrapper class for OpBuilder to hide the native
-  // TableGen API calls here.
-  {
-    auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
-    if (listInit) {
-      for (Init *init : listInit->getValues()) {
-        Record *builderDef = cast<DefInit>(init)->getDef();
-        std::string paramStr = builderSignatureFromDAG(
-            builderDef->getValueAsDag("dagParams"), op.getLoc());
-
-        StringRef body = builderDef->getValueAsString("body");
-        bool hasBody = !body.empty();
-        OpMethod::Property properties =
-            hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
-        auto *method =
-            opClass.addMethodAndPrune("void", "build", properties, paramStr);
+  for (const Builder &builder : op.getBuilders()) {
+    std::string paramStr = getBuilderSignature(builder);
 
-        FmtContext fctx;
-        fctx.withBuilder(builder);
-        fctx.addSubst("_state", builderOpState);
-        if (hasBody)
-          method->body() << tgfmt(body, &fctx);
-      }
-    }
-    if (op.skipDefaultBuilders()) {
-      if (!listInit || listInit->empty())
-        PrintFatalError(
-            op.getLoc(),
-            "default builders are skipped and no custom builders provided");
-      return;
-    }
+    Optional<StringRef> body = builder.getBody();
+    OpMethod::Property properties =
+        body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
+    auto *method =
+        opClass.addMethodAndPrune("void", "build", properties, paramStr);
+
+    FmtContext fctx;
+    fctx.withBuilder(odsBuilder);
+    fctx.addSubst("_state", builderOpState);
+    if (body)
+      method->body() << tgfmt(*body, &fctx);
   }
 
   // Generate default builders that requires all result type, operands, and
   // attributes as parameters.
+  if (op.skipDefaultBuilders())
+    return;
 
   // We generate three classes of builders here:
   // 1. one having a stand-alone parameter for each operand / attribute, and


        


More information about the llvm-branch-commits mailing list