[Mlir-commits] [mlir] 4b8632e - [mlir] Expand operand adapter to take attributes

Jacques Pienaar llvmlistbot at llvm.org
Sun May 24 21:07:38 PDT 2020


Author: Jacques Pienaar
Date: 2020-05-24T21:06:47-07:00
New Revision: 4b8632e174d5ba79c4858a1245b96efd3ed281fb

URL: https://github.com/llvm/llvm-project/commit/4b8632e174d5ba79c4858a1245b96efd3ed281fb
DIFF: https://github.com/llvm/llvm-project/commit/4b8632e174d5ba79c4858a1245b96efd3ed281fb.diff

LOG: [mlir] Expand operand adapter to take attributes

* Enables using the adaptor with more ops that have variadic-sized operands;
* Generates convenience accessors for attributes;
  - The accessors are named the same as the attributes in ODS and return the
    attribute storage type (not the convenience type); no accessors are
    generated for derived attributes.

This is the first step towards changing the adaptor to support verifying argument
constraints before the op is even created. To keep this change small, it does not
rename the adaptor, and the new attributes argument is only required for ops with
attribute-sized variadic operands (AttrSizedOperandSegments).

Considered creating a separate adaptor but decided against it, given that operands generally also require attributes (and definitely do for verification of operands and attributes).

Differential Revision: https://reviews.llvm.org/D80420

Added: 
    

Modified: 
    mlir/include/mlir/TableGen/OpClass.h
    mlir/lib/TableGen/OpClass.cpp
    mlir/test/mlir-tblgen/op-decl.td
    mlir/test/mlir-tblgen/op-operand.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/OpClass.h b/mlir/include/mlir/TableGen/OpClass.h
index 8788a505a4b3..e8f73c605dfd 100644
--- a/mlir/include/mlir/TableGen/OpClass.h
+++ b/mlir/include/mlir/TableGen/OpClass.h
@@ -145,10 +145,6 @@ class OpClass : public Class {
 public:
   explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
 
-  // Sets whether this OpClass should generate the using directive for its
-  // associate operand adaptor class.
-  void setHasOperandAdaptorClass(bool has);
-
   // Adds an op trait.
   void addTrait(Twine trait);
 
@@ -160,7 +156,6 @@ class OpClass : public Class {
   StringRef extraClassDeclaration;
   SmallVector<std::string, 4> traitsVec;
   StringSet<> traitsSet;
-  bool hasOperandAdaptor;
 };
 
 } // namespace tblgen

diff  --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/OpClass.cpp
index 26519df72534..bfdcbdc344a3 100644
--- a/mlir/lib/TableGen/OpClass.cpp
+++ b/mlir/lib/TableGen/OpClass.cpp
@@ -188,12 +188,7 @@ void tblgen::Class::writeDefTo(raw_ostream &os) const {
 //===----------------------------------------------------------------------===//
 
 tblgen::OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
-    : Class(name), extraClassDeclaration(extraClassDeclaration),
-      hasOperandAdaptor(true) {}
-
-void tblgen::OpClass::setHasOperandAdaptorClass(bool has) {
-  hasOperandAdaptor = has;
-}
+    : Class(name), extraClassDeclaration(extraClassDeclaration) {}
 
 void tblgen::OpClass::addTrait(Twine trait) {
   auto traitStr = trait.str();
@@ -207,8 +202,7 @@ void tblgen::OpClass::writeDeclTo(raw_ostream &os) const {
     os << ", " << trait;
   os << "> {\npublic:\n";
   os << "  using Op::Op;\n";
-  if (hasOperandAdaptor)
-    os << "  using OperandAdaptor = " << className << "OperandAdaptor;\n";
+  os << "  using OperandAdaptor = " << className << "OperandAdaptor;\n";
 
   bool hasPrivateMethod = false;
   for (const auto &method : methods) {

diff  --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 0b9bac2ecb4c..c68d03c96b30 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -50,12 +50,14 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 
 // CHECK: class AOpOperandAdaptor {
 // CHECK: public:
-// CHECK:   AOpOperandAdaptor(ArrayRef<Value> values);
+// CHECK:   AOpOperandAdaptor(ArrayRef<Value> values
 // CHECK:   ArrayRef<Value> getODSOperands(unsigned index);
 // CHECK:   Value a();
 // CHECK:   ArrayRef<Value> b();
+// CHECK:   IntegerAttr attr1();
+// CHECK:   FloatAttr attr2();
 // CHECK: private:
-// CHECK:   ArrayRef<Value> tblgen_operands;
+// CHECK:   ArrayRef<Value> odsOperands;
 // CHECK: };
 
 // CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNRegions<1>::Impl, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove
@@ -90,6 +92,29 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:   void displayGraph();
 // CHECK: };
 
+// Check AttrSizedOperandSegments
+// ---
+
+def NS_AttrSizedOperandOp : NS_Op<"attr_sized_operands",
+                                 [AttrSizedOperandSegments]> {
+  let arguments = (ins
+    Variadic<I32>:$a,
+    Variadic<I32>:$b,
+    I32:$c,
+    Variadic<I32>:$d,
+    I32ElementsAttr:$operand_segment_sizes
+  );
+}
+
+// CHECK-LABEL: AttrSizedOperandOpOperandAdaptor(
+// CHECK-SAME:    ArrayRef<Value> values
+// CHECK-SAME:    DictionaryAttr attrs
+// CHECK:  ArrayRef<Value> a();
+// CHECK:  ArrayRef<Value> b();
+// CHECK:  Value c();
+// CHECK:  ArrayRef<Value> d();
+// CHECK:  DenseIntElementsAttr operand_segment_sizes();
+
 // Check op trait for different number of operands
 // ---
 
@@ -150,3 +175,4 @@ def _BOp : NS_Op<"_op_with_leading_underscore_and_no_namespace", []>;
 
 // CHECK-LABEL: _BOp declarations
 // CHECK: class _BOp : public Op<_BOp
+

diff  --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td
index 2ffde33c5331..5f0bfae92812 100644
--- a/mlir/test/mlir-tblgen/op-operand.td
+++ b/mlir/test/mlir-tblgen/op-operand.td
@@ -15,7 +15,7 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
 // CHECK-LABEL: OpA definitions
 
 // CHECK:      OpAOperandAdaptor::OpAOperandAdaptor
-// CHECK-NEXT: tblgen_operands = values
+// CHECK-NEXT: odsOperands = values
 
 // CHECK:      void OpA::build
 // CHECK:        Value input

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 8709760e2c6b..2010262f2185 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -70,13 +70,19 @@ const char *sameVariadicSizeValueRangeCalcCode = R"(
 // (variadic or not).
 //
 // {0}: The name of the attribute specifying the segment sizes.
-const char *attrSizedSegmentValueRangeCalcCode = R"(
+const char *adapterSegmentSizeAttrInitCode = R"(
+  assert(odsAttrs && "missing segment size attribute for op");
+  auto sizeAttr = odsAttrs.get("{0}").cast<DenseIntElementsAttr>();
+)";
+const char *opSegmentSizeAttrInitCode = R"(
   auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
+)";
+const char *attrSizedSegmentValueRangeCalcCode = R"(
   unsigned start = 0;
   for (unsigned i = 0; i < index; ++i)
     start += (*(sizeAttr.begin() + i)).getZExtValue();
   unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
-  return {{start, size};
+  return {start, size};
 )";
 
 // The logic to build a range of either operand or result values.
@@ -496,15 +502,14 @@ static void
 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
                               int numVariadic, int numNonVariadic,
                               StringRef rangeSizeCall, bool hasAttrSegmentSize,
-                              StringRef segmentSizeAttr, RangeT &&odsValues) {
+                              StringRef sizeAttrInit, RangeT &&odsValues) {
   auto &method = opClass.newMethod("std::pair<unsigned, unsigned>", methodName,
                                    "unsigned index");
 
   if (numVariadic == 0) {
     method.body() << "  return {index, 1};\n";
   } else if (hasAttrSegmentSize) {
-    method.body() << formatv(attrSizedSegmentValueRangeCalcCode,
-                             segmentSizeAttr);
+    method.body() << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
   } else {
     // Because the op can have arbitrarily interleaved variadic and non-variadic
     // operands, we need to embed a list in the "sink" getter method for
@@ -532,6 +537,7 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
 // of ops, in particular for one-operand ops that may not have the
 // `getOperand(unsigned)` method.
 static void generateNamedOperandGetters(const Operator &op, Class &opClass,
+                                        StringRef sizeAttrInit,
                                         StringRef rangeType,
                                         StringRef rangeBeginCall,
                                         StringRef rangeSizeCall,
@@ -563,10 +569,10 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
 
   // First emit a few "sink" getter methods upon which we layer all nicer named
   // getter methods.
-  generateValueRangeStartAndEnd(
-      opClass, "getODSOperandIndexAndLength", numVariadicOperands,
-      numNormalOperands, rangeSizeCall, attrSizedOperands,
-      "operand_segment_sizes", const_cast<Operator &>(op).getOperands());
+  generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
+                                numVariadicOperands, numNormalOperands,
+                                rangeSizeCall, attrSizedOperands, sizeAttrInit,
+                                const_cast<Operator &>(op).getOperands());
 
   auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
   m.body() << formatv(valueRangeReturnCode, rangeBeginCall,
@@ -574,7 +580,6 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
 
   // Then we emit nicer named getter methods by redirecting to the "sink" getter
   // method.
-
   for (int i = 0; i != numOperands; ++i) {
     const auto &operand = op.getOperand(i);
     if (operand.name.empty())
@@ -595,11 +600,11 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
 }
 
 void OpEmitter::genNamedOperandGetters() {
-  if (op.getTrait("OpTrait::AttrSizedOperandSegments"))
-    opClass.setHasOperandAdaptorClass(false);
-
   generateNamedOperandGetters(
-      op, opClass, /*rangeType=*/"Operation::operand_range",
+      op, opClass,
+      /*sizeAttrInit=*/
+      formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(),
+      /*rangeType=*/"Operation::operand_range",
       /*rangeBeginCall=*/"getOperation()->operand_begin()",
       /*rangeSizeCall=*/"getOperation()->getNumOperands()",
       /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
@@ -656,7 +661,8 @@ void OpEmitter::genNamedResultGetters() {
   generateValueRangeStartAndEnd(
       opClass, "getODSResultIndexAndLength", numVariadicResults,
       numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
-      "result_segment_sizes", op.getResults());
+      formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(),
+      op.getResults());
   auto &m = opClass.newMethod("Operation::result_range", "getODSResults",
                               "unsigned index");
   m.body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
@@ -1840,15 +1846,56 @@ class OpOperandAdaptorEmitter {
 
 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
     : adapterClass(op.getCppClassName().str() + "OperandAdaptor") {
-  adapterClass.newField("ArrayRef<Value>", "tblgen_operands");
-  auto &constructor = adapterClass.newConstructor("ArrayRef<Value> values");
-  constructor.body() << "  tblgen_operands = values;\n";
-
-  generateNamedOperandGetters(op, adapterClass,
+  adapterClass.newField("ArrayRef<Value>", "odsOperands");
+  adapterClass.newField("DictionaryAttr", "odsAttrs");
+  const auto *attrSizedOperands =
+      op.getTrait("OpTrait::AttrSizedOperandSegments");
+  auto &constructor = adapterClass.newConstructor(
+      attrSizedOperands
+          ? "ArrayRef<Value> values, DictionaryAttr attrs"
+          : "ArrayRef<Value> values, DictionaryAttr attrs = nullptr");
+  constructor.body() << "  odsOperands = values;\n";
+  constructor.body() << "  odsAttrs = attrs;\n";
+
+  std::string sizeAttrInit =
+      formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
+  generateNamedOperandGetters(op, adapterClass, sizeAttrInit,
                               /*rangeType=*/"ArrayRef<Value>",
-                              /*rangeBeginCall=*/"tblgen_operands.begin()",
-                              /*rangeSizeCall=*/"tblgen_operands.size()",
-                              /*getOperandCallPattern=*/"tblgen_operands[{0}]");
+                              /*rangeBeginCall=*/"odsOperands.begin()",
+                              /*rangeSizeCall=*/"odsOperands.size()",
+                              /*getOperandCallPattern=*/"odsOperands[{0}]");
+
+  FmtContext fctx;
+  fctx.withBuilder("mlir::Builder(odsAttrs.getContext())");
+
+  auto emitAttr = [&](StringRef name, Attribute attr) {
+    auto &body = adapterClass.newMethod(attr.getStorageType(), name).body();
+    body << "  assert(odsAttrs && \"no attributes when constructing adapter\");"
+         << "\n  " << attr.getStorageType() << " attr = "
+         << "odsAttrs.get(\"" << name << "\").";
+    if (attr.hasDefaultValue() || attr.isOptional())
+      body << "dyn_cast_or_null<";
+    else
+      body << "cast<";
+    body << attr.getStorageType() << ">();\n";
+
+    if (attr.hasDefaultValue()) {
+      // Use the default value if attribute is not set.
+      // TODO: this is inefficient, we are recreating the attribute for every
+      // call. This should be set instead.
+      std::string defaultValue = std::string(
+          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
+      body << "  if (!attr)\n    attr = " << defaultValue << ";\n";
+    }
+    body << "  return attr;\n";
+  };
+
+  for (auto &namedAttr : op.getAttributes()) {
+    const auto &name = namedAttr.name;
+    const auto &attr = namedAttr.attr;
+    if (!attr.isDerivedAttr())
+      emitAttr(name, attr);
+  }
 }
 
 void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
@@ -1873,19 +1920,13 @@ static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
   }
   for (auto *def : defs) {
     Operator op(*def);
-    const auto *attrSizedOperands =
-        op.getTrait("OpTrait::AttrSizedOperandSegments");
     if (emitDecl) {
       os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
-      // We cannot generate the operand adaptor class if operand getters depend
-      // on an attribute.
-      if (!attrSizedOperands)
-        OpOperandAdaptorEmitter::emitDecl(op, os);
+      OpOperandAdaptorEmitter::emitDecl(op, os);
       OpEmitter::emitDecl(op, os);
     } else {
       os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
-      if (!attrSizedOperands)
-        OpOperandAdaptorEmitter::emitDef(op, os);
+      OpOperandAdaptorEmitter::emitDef(op, os);
       OpEmitter::emitDef(op, os);
     }
   }


        


More information about the Mlir-commits mailing list