[Mlir-commits] [mlir] fefe436 - [mlir] Use ValueRange instead of ArrayRef<Value>

Jacques Pienaar llvmlistbot at llvm.org
Thu May 28 09:05:37 PDT 2020


Author: Jacques Pienaar
Date: 2020-05-28T09:05:24-07:00
New Revision: fefe4366c3bdd03552c448972930a0f7df328c24

URL: https://github.com/llvm/llvm-project/commit/fefe4366c3bdd03552c448972930a0f7df328c24
DIFF: https://github.com/llvm/llvm-project/commit/fefe4366c3bdd03552c448972930a0f7df328c24.diff

LOG: [mlir] Use ValueRange instead of ArrayRef<Value>

This allows constructing an operand adaptor from an existing op (useful for sharing common verification logic, which I plan to do in a follow-up).

For convenience, I also added the ability to use member initializers in the generated adaptor constructors.

Differential Revision: https://reviews.llvm.org/D80667

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
    mlir/include/mlir/TableGen/OpClass.h
    mlir/include/mlir/TableGen/Operator.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
    mlir/lib/TableGen/OpClass.cpp
    mlir/lib/TableGen/Operator.cpp
    mlir/test/mlir-tblgen/op-decl.td
    mlir/test/mlir-tblgen/op-operand.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 2eae578fc966..c241de6ff6fe 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -438,12 +438,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
   // This is a strided getElementPtr variant that linearizes subscripts as:
   //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
   Value getStridedElementPtr(Location loc, Type elementTypePtr,
-                             Value descriptor, ArrayRef<Value> indices,
+                             Value descriptor, ValueRange indices,
                              ArrayRef<int64_t> strides, int64_t offset,
                              ConversionPatternRewriter &rewriter) const;
 
   Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
-                   ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
+                   ValueRange indices, ConversionPatternRewriter &rewriter,
                    llvm::Module &module) const;
 
 protected:

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index 1fa668d7ddc0..f0a429941fb3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -124,7 +124,7 @@ Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
 // with AffineMap that has static strides. Extend to handle dynamic strides.
 spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
                                    MemRefType baseType, Value basePtr,
-                                   ArrayRef<Value> indices, Location loc,
+                                   ValueRange indices, Location loc,
                                    OpBuilder &builder);
 
 /// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its

diff  --git a/mlir/include/mlir/TableGen/OpClass.h b/mlir/include/mlir/TableGen/OpClass.h
index e8f73c605dfd..694fed767e33 100644
--- a/mlir/include/mlir/TableGen/OpClass.h
+++ b/mlir/include/mlir/TableGen/OpClass.h
@@ -86,6 +86,7 @@ class OpMethod {
 
   OpMethod(StringRef retType, StringRef name, StringRef params,
            Property property, bool declOnly);
+  virtual ~OpMethod() = default;
 
   OpMethodBody &body();
 
@@ -96,13 +97,13 @@ class OpMethod {
   bool isPrivate() const;
 
   // Writes the method as a declaration to the given `os`.
-  void writeDeclTo(raw_ostream &os) const;
+  virtual void writeDeclTo(raw_ostream &os) const;
   // Writes the method as a definition to the given `os`. `namePrefix` is the
   // prefix to be prepended to the method name (typically namespaces for
   // qualifying the method definition).
-  void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
+  virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
 
-private:
+protected:
   Property properties;
   // Whether this method only contains a declaration.
   bool isDeclOnly;
@@ -110,6 +111,26 @@ class OpMethod {
   OpMethodBody methodBody;
 };
 
+// Class for holding an op's constructor method for C++ code emission.
+class OpConstructor : public OpMethod {
+public:
+  OpConstructor(StringRef retType, StringRef name, StringRef params,
+                Property property, bool declOnly)
+      : OpMethod(retType, name, params, property, declOnly){};
+
+  // Add member initializer to constructor initializing `name` with `value`.
+  void addMemberInitializer(StringRef name, StringRef value);
+
+  // Writes the method as a definition to the given `os`. `namePrefix` is the
+  // prefix to be prepended to the method name (typically namespaces for
+  // qualifying the method definition).
+  void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
+
+private:
+  // Member initializers.
+  std::string memberInitializers;
+};
+
 // A class used to emit C++ classes from Tablegen.  Contains a list of public
 // methods and a list of private fields to be emitted.
 class Class {
@@ -121,7 +142,7 @@ class Class {
                       OpMethod::Property = OpMethod::MP_None,
                       bool declOnly = false);
 
-  OpMethod &newConstructor(StringRef params = "", bool declOnly = false);
+  OpConstructor &newConstructor(StringRef params = "", bool declOnly = false);
 
   // Creates a new field in this class.
   void newField(StringRef type, StringRef name, StringRef defaultValue = "");
@@ -136,6 +157,7 @@ class Class {
 
 protected:
   std::string className;
+  SmallVector<OpConstructor, 2> constructors;
   SmallVector<OpMethod, 8> methods;
   SmallVector<std::string, 4> fields;
 };

diff  --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 040f52314cea..cce754dd3454 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -58,6 +58,9 @@ class Operator {
   // Returns this op's C++ class name prefixed with namespaces.
   std::string getQualCppClassName() const;
 
+  // Returns the name of op's adaptor C++ class.
+  std::string getAdaptorName() const;
+
   /// A class used to represent the decorators of an operator variable, i.e.
   /// argument or result.
   struct VariableDecorator {

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index cbe6da31addf..8cc2315ddd15 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -795,8 +795,8 @@ Value ConvertToLLVMPattern::linearizeSubscripts(
 }
 
 Value ConvertToLLVMPattern::getStridedElementPtr(
-    Location loc, Type elementTypePtr, Value descriptor,
-    ArrayRef<Value> indices, ArrayRef<int64_t> strides, int64_t offset,
+    Location loc, Type elementTypePtr, Value descriptor, ValueRange indices,
+    ArrayRef<int64_t> strides, int64_t offset,
     ConversionPatternRewriter &rewriter) const {
   MemRefDescriptor memRefDescriptor(descriptor);
 
@@ -818,8 +818,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
 }
 
 Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
-                                       Value memRefDesc,
-                                       ArrayRef<Value> indices,
+                                       Value memRefDesc, ValueRange indices,
                                        ConversionPatternRewriter &rewriter,
                                        llvm::Module &module) const {
   LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
@@ -2602,7 +2601,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
   // Build and return the value for the idx^th shape dimension, either by
   // returning the constant shape dimension or counting the proper dynamic size.
   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
-                ArrayRef<int64_t> shape, ArrayRef<Value> dynamicSizes,
+                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
                 unsigned idx) const {
     assert(idx < shape.size());
     if (!ShapedType::isDynamic(shape[idx]))

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index dfc2728ef710..6458756dec69 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -579,7 +579,7 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
 
 spirv::AccessChainOp mlir::spirv::getElementPtr(
     SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
-    ArrayRef<Value> indices, Location loc, OpBuilder &builder) {
+    ValueRange indices, Location loc, OpBuilder &builder) {
   // Get base and offset of the MemRefType and verify they are static.
 
   int64_t offset;
@@ -591,6 +591,7 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
   }
 
   auto indexType = typeConverter.getIndexType(builder.getContext());
+
   SmallVector<Value, 2> linearizedIndices;
   // Add a '0' at the start to index into the struct.
   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
@@ -606,7 +607,7 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
         loc, indexType, IntegerAttr::get(indexType, offset));
     assert(indices.size() == strides.size() &&
            "must provide indices for all dimensions");
-    for (auto index : enumerate(indices)) {
+    for (auto index : llvm::enumerate(indices)) {
       Value strideVal = builder.create<spirv::ConstantOp>(
           loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
       Value update =

diff  --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/OpClass.cpp
index bfdcbdc344a3..43bbe2420a9a 100644
--- a/mlir/lib/TableGen/OpClass.cpp
+++ b/mlir/lib/TableGen/OpClass.cpp
@@ -119,6 +119,27 @@ void tblgen::OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
   os << "}";
 }
 
+//===----------------------------------------------------------------------===//
+// OpConstructor definitions
+//===----------------------------------------------------------------------===//
+
+void mlir::tblgen::OpConstructor::addMemberInitializer(StringRef name,
+                                                       StringRef value) {
+  memberInitializers.append(std::string(llvm::formatv(
+      "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
+}
+
+void mlir::tblgen::OpConstructor::writeDefTo(raw_ostream &os,
+                                             StringRef namePrefix) const {
+  if (isDeclOnly)
+    return;
+
+  methodSignature.writeDefTo(os, namePrefix);
+  os << " " << memberInitializers << " {\n";
+  methodBody.writeTo(os);
+  os << "}";
+}
+
 //===----------------------------------------------------------------------===//
 // Class definitions
 //===----------------------------------------------------------------------===//
@@ -133,10 +154,11 @@ tblgen::OpMethod &tblgen::Class::newMethod(StringRef retType, StringRef name,
   return methods.back();
 }
 
-tblgen::OpMethod &tblgen::Class::newConstructor(StringRef params,
-                                                bool declOnly) {
-  return newMethod("", getClassName(), params, OpMethod::MP_Constructor,
-                   declOnly);
+tblgen::OpConstructor &tblgen::Class::newConstructor(StringRef params,
+                                                     bool declOnly) {
+  constructors.emplace_back("", getClassName(), params,
+                            OpMethod::MP_Constructor, declOnly);
+  return constructors.back();
 }
 
 void tblgen::Class::newField(StringRef type, StringRef name,
@@ -152,7 +174,8 @@ void tblgen::Class::writeDeclTo(raw_ostream &os) const {
   bool hasPrivateMethod = false;
   os << "class " << className << " {\n";
   os << "public:\n";
-  for (const auto &method : methods) {
+  for (const auto &method :
+       llvm::concat<const OpMethod>(constructors, methods)) {
     if (!method.isPrivate()) {
       method.writeDeclTo(os);
       os << '\n';
@@ -163,7 +186,8 @@ void tblgen::Class::writeDeclTo(raw_ostream &os) const {
   os << '\n';
   os << "private:\n";
   if (hasPrivateMethod) {
-    for (const auto &method : methods) {
+    for (const auto &method :
+         llvm::concat<const OpMethod>(constructors, methods)) {
       if (method.isPrivate()) {
         method.writeDeclTo(os);
         os << '\n';
@@ -177,7 +201,8 @@ void tblgen::Class::writeDeclTo(raw_ostream &os) const {
 }
 
 void tblgen::Class::writeDefTo(raw_ostream &os) const {
-  for (const auto &method : methods) {
+  for (const auto &method :
+       llvm::concat<const OpMethod>(constructors, methods)) {
     method.writeDefTo(os, className);
     os << "\n\n";
   }

diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 2f77184980e2..f575fedc1f24 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -59,6 +59,10 @@ std::string tblgen::Operator::getOperationName() const {
   return std::string(llvm::formatv("{0}.{1}", prefix, opName));
 }
 
+std::string tblgen::Operator::getAdaptorName() const {
+  return std::string(llvm::formatv("{0}OperandAdaptor", getCppClassName()));
+}
+
 StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); }
 
 StringRef tblgen::Operator::getCppClassName() const { return cppClassName; }

diff  --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 565f1921125a..a101103b08fc 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -49,14 +49,14 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 
 // CHECK: class AOpOperandAdaptor {
 // CHECK: public:
-// CHECK:   AOpOperandAdaptor(ArrayRef<Value> values
-// CHECK:   ArrayRef<Value> getODSOperands(unsigned index);
+// CHECK:   AOpOperandAdaptor(ValueRange values
+// CHECK:   ValueRange getODSOperands(unsigned index);
 // CHECK:   Value a();
-// CHECK:   ArrayRef<Value> b();
+// CHECK:   ValueRange b();
 // CHECK:   IntegerAttr attr1();
 // CHECL:   FloatAttr attr2();
 // CHECK: private:
-// CHECK:   ArrayRef<Value> odsOperands;
+// CHECK:   ValueRange odsOperands;
 // CHECK: };
 
 // CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNRegions<1>::Impl, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove
@@ -106,12 +106,12 @@ def NS_AttrSizedOperandOp : NS_Op<"attr_sized_operands",
 }
 
 // CHECK-LABEL: AttrSizedOperandOpOperandAdaptor(
-// CHECK-SAME:    ArrayRef<Value> values
+// CHECK-SAME:    ValueRange values
 // CHECK-SAME:    DictionaryAttr attrs
-// CHECK:  ArrayRef<Value> a();
-// CHECK:  ArrayRef<Value> b();
+// CHECK:  ValueRange a();
+// CHECK:  ValueRange b();
 // CHECK:  Value c();
-// CHECK:  ArrayRef<Value> d();
+// CHECK:  ValueRange d();
 // CHECK:  DenseIntElementsAttr operand_segment_sizes();
 
 // Check op trait for different number of operands

diff  --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td
index 5f0bfae92812..a9b61c179be0 100644
--- a/mlir/test/mlir-tblgen/op-operand.td
+++ b/mlir/test/mlir-tblgen/op-operand.td
@@ -15,7 +15,7 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
 // CHECK-LABEL: OpA definitions
 
 // CHECK:      OpAOperandAdaptor::OpAOperandAdaptor
-// CHECK-NEXT: odsOperands = values
+// CHECK-SAME: odsOperands(values), odsAttrs(attrs)
 
 // CHECK:      void OpA::build
 // CHECK:        Value input
@@ -39,13 +39,13 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
   let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
 }
 
-// CHECK-LABEL: ArrayRef<Value> OpDOperandAdaptor::input1
+// CHECK-LABEL: ValueRange OpDOperandAdaptor::input1
 // CHECK-NEXT:    return getODSOperands(0);
 
 // CHECK-LABEL: Value OpDOperandAdaptor::input2
 // CHECK-NEXT:    return *getODSOperands(1).begin();
 
-// CHECK-LABEL: ArrayRef<Value> OpDOperandAdaptor::input3
+// CHECK-LABEL: ValueRange OpDOperandAdaptor::input3
 // CHECK-NEXT:    return getODSOperands(2);
 
 // CHECK-LABEL: Operation::operand_range OpD::input1

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 0b55825d1a46..7b0cd9d7a482 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1890,27 +1890,38 @@ class OpOperandAdaptorEmitter {
 private:
   explicit OpOperandAdaptorEmitter(const Operator &op);
 
-  Class adapterClass;
+  Class adaptor;
 };
 } // end namespace
 
 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
-    : adapterClass(op.getCppClassName().str() + "OperandAdaptor") {
-  adapterClass.newField("ArrayRef<Value>", "odsOperands");
-  adapterClass.newField("DictionaryAttr", "odsAttrs");
+    : adaptor(op.getAdaptorName()) {
+  adaptor.newField("ValueRange", "odsOperands");
+  adaptor.newField("DictionaryAttr", "odsAttrs");
   const auto *attrSizedOperands =
       op.getTrait("OpTrait::AttrSizedOperandSegments");
-  auto &constructor = adapterClass.newConstructor(
-      attrSizedOperands
-          ? "ArrayRef<Value> values, DictionaryAttr attrs"
-          : "ArrayRef<Value> values, DictionaryAttr attrs = nullptr");
-  constructor.body() << "  odsOperands = values;\n";
-  constructor.body() << "  odsAttrs = attrs;\n";
+  {
+    auto &constructor = adaptor.newConstructor(
+        attrSizedOperands
+            ? "ValueRange values, DictionaryAttr attrs"
+            : "ValueRange values, DictionaryAttr attrs = nullptr");
+    constructor.addMemberInitializer("odsOperands", "values");
+    constructor.addMemberInitializer("odsAttrs", "attrs");
+  }
+
+  {
+    auto &constructor = adaptor.newConstructor(
+        llvm::formatv("{0}& op", op.getCppClassName()).str());
+    constructor.addMemberInitializer("odsOperands",
+                                     "op.getOperation()->getOperands()");
+    constructor.addMemberInitializer("odsAttrs",
+                                     "op.getOperation()->getAttrDictionary()");
+  }
 
   std::string sizeAttrInit =
       formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
-  generateNamedOperandGetters(op, adapterClass, sizeAttrInit,
-                              /*rangeType=*/"ArrayRef<Value>",
+  generateNamedOperandGetters(op, adaptor, sizeAttrInit,
+                              /*rangeType=*/"ValueRange",
                               /*rangeBeginCall=*/"odsOperands.begin()",
                               /*rangeSizeCall=*/"odsOperands.size()",
                               /*getOperandCallPattern=*/"odsOperands[{0}]");
@@ -1919,7 +1930,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
   fctx.withBuilder("mlir::Builder(odsAttrs.getContext())");
 
   auto emitAttr = [&](StringRef name, Attribute attr) {
-    auto &body = adapterClass.newMethod(attr.getStorageType(), name).body();
+    auto &body = adaptor.newMethod(attr.getStorageType(), name).body();
     body << "  assert(odsAttrs && \"no attributes when constructing adapter\");"
          << "\n  " << attr.getStorageType() << " attr = "
          << "odsAttrs.get(\"" << name << "\").";
@@ -1949,11 +1960,11 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
 }
 
 void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
-  OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os);
+  OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
 }
 
 void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
-  OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os);
+  OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
 }
 
 // Emits the opcode enum and op classes.


        


More information about the Mlir-commits mailing list