[Mlir-commits] [mlir] fefe436 - [mlir] Use ValueRange instead of ArrayRef<Value>
Jacques Pienaar
llvmlistbot at llvm.org
Thu May 28 09:05:37 PDT 2020
Author: Jacques Pienaar
Date: 2020-05-28T09:05:24-07:00
New Revision: fefe4366c3bdd03552c448972930a0f7df328c24
URL: https://github.com/llvm/llvm-project/commit/fefe4366c3bdd03552c448972930a0f7df328c24
DIFF: https://github.com/llvm/llvm-project/commit/fefe4366c3bdd03552c448972930a0f7df328c24.diff
LOG: [mlir] Use ValueRange instead of ArrayRef<Value>
This allows constructing an operand adaptor from an existing op (useful for commonalizing verification, which I want to do in a follow-up).
I also added the ability to use member initializers in the generated adaptor constructors, for convenience.
Differential Revision: https://reviews.llvm.org/D80667
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
mlir/include/mlir/TableGen/OpClass.h
mlir/include/mlir/TableGen/Operator.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/lib/TableGen/OpClass.cpp
mlir/lib/TableGen/Operator.cpp
mlir/test/mlir-tblgen/op-decl.td
mlir/test/mlir-tblgen/op-operand.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 2eae578fc966..c241de6ff6fe 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -438,12 +438,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
Value getStridedElementPtr(Location loc, Type elementTypePtr,
- Value descriptor, ArrayRef<Value> indices,
+ Value descriptor, ValueRange indices,
ArrayRef<int64_t> strides, int64_t offset,
ConversionPatternRewriter &rewriter) const;
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
- ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
+ ValueRange indices, ConversionPatternRewriter &rewriter,
llvm::Module &module) const;
protected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index 1fa668d7ddc0..f0a429941fb3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -124,7 +124,7 @@ Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
// with AffineMap that has static strides. Extend to handle dynamic strides.
spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
MemRefType baseType, Value basePtr,
- ArrayRef<Value> indices, Location loc,
+ ValueRange indices, Location loc,
OpBuilder &builder);
/// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its
diff --git a/mlir/include/mlir/TableGen/OpClass.h b/mlir/include/mlir/TableGen/OpClass.h
index e8f73c605dfd..694fed767e33 100644
--- a/mlir/include/mlir/TableGen/OpClass.h
+++ b/mlir/include/mlir/TableGen/OpClass.h
@@ -86,6 +86,7 @@ class OpMethod {
OpMethod(StringRef retType, StringRef name, StringRef params,
Property property, bool declOnly);
+ virtual ~OpMethod() = default;
OpMethodBody &body();
@@ -96,13 +97,13 @@ class OpMethod {
bool isPrivate() const;
// Writes the method as a declaration to the given `os`.
- void writeDeclTo(raw_ostream &os) const;
+ virtual void writeDeclTo(raw_ostream &os) const;
// Writes the method as a definition to the given `os`. `namePrefix` is the
// prefix to be prepended to the method name (typically namespaces for
// qualifying the method definition).
- void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
+ virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
-private:
+protected:
Property properties;
// Whether this method only contains a declaration.
bool isDeclOnly;
@@ -110,6 +111,26 @@ class OpMethod {
OpMethodBody methodBody;
};
+// Class for holding an op's constructor method for C++ code emission.
+class OpConstructor : public OpMethod {
+public:
+ OpConstructor(StringRef retType, StringRef name, StringRef params,
+ Property property, bool declOnly)
+ : OpMethod(retType, name, params, property, declOnly){};
+
+ // Add member initializer to constructor initializing `name` with `value`.
+ void addMemberInitializer(StringRef name, StringRef value);
+
+ // Writes the method as a definition to the given `os`. `namePrefix` is the
+ // prefix to be prepended to the method name (typically namespaces for
+ // qualifying the method definition).
+ void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
+
+private:
+ // Member initializers.
+ std::string memberInitializers;
+};
+
// A class used to emit C++ classes from Tablegen. Contains a list of public
// methods and a list of private fields to be emitted.
class Class {
@@ -121,7 +142,7 @@ class Class {
OpMethod::Property = OpMethod::MP_None,
bool declOnly = false);
- OpMethod &newConstructor(StringRef params = "", bool declOnly = false);
+ OpConstructor &newConstructor(StringRef params = "", bool declOnly = false);
// Creates a new field in this class.
void newField(StringRef type, StringRef name, StringRef defaultValue = "");
@@ -136,6 +157,7 @@ class Class {
protected:
std::string className;
+ SmallVector<OpConstructor, 2> constructors;
SmallVector<OpMethod, 8> methods;
SmallVector<std::string, 4> fields;
};
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 040f52314cea..cce754dd3454 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -58,6 +58,9 @@ class Operator {
// Returns this op's C++ class name prefixed with namespaces.
std::string getQualCppClassName() const;
+ // Returns the name of op's adaptor C++ class.
+ std::string getAdaptorName() const;
+
/// A class used to represent the decorators of an operator variable, i.e.
/// argument or result.
struct VariableDecorator {
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index cbe6da31addf..8cc2315ddd15 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -795,8 +795,8 @@ Value ConvertToLLVMPattern::linearizeSubscripts(
}
Value ConvertToLLVMPattern::getStridedElementPtr(
- Location loc, Type elementTypePtr, Value descriptor,
- ArrayRef<Value> indices, ArrayRef<int64_t> strides, int64_t offset,
+ Location loc, Type elementTypePtr, Value descriptor, ValueRange indices,
+ ArrayRef<int64_t> strides, int64_t offset,
ConversionPatternRewriter &rewriter) const {
MemRefDescriptor memRefDescriptor(descriptor);
@@ -818,8 +818,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
}
Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
- Value memRefDesc,
- ArrayRef<Value> indices,
+ Value memRefDesc, ValueRange indices,
ConversionPatternRewriter &rewriter,
llvm::Module &module) const {
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
@@ -2602,7 +2601,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
// Build and return the value for the idx^th shape dimension, either by
// returning the constant shape dimension or counting the proper dynamic size.
Value getSize(ConversionPatternRewriter &rewriter, Location loc,
- ArrayRef<int64_t> shape, ArrayRef<Value> dynamicSizes,
+ ArrayRef<int64_t> shape, ValueRange dynamicSizes,
unsigned idx) const {
assert(idx < shape.size());
if (!ShapedType::isDynamic(shape[idx]))
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index dfc2728ef710..6458756dec69 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -579,7 +579,7 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op,
spirv::AccessChainOp mlir::spirv::getElementPtr(
SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
- ArrayRef<Value> indices, Location loc, OpBuilder &builder) {
+ ValueRange indices, Location loc, OpBuilder &builder) {
// Get base and offset of the MemRefType and verify they are static.
int64_t offset;
@@ -591,6 +591,7 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
}
auto indexType = typeConverter.getIndexType(builder.getContext());
+
SmallVector<Value, 2> linearizedIndices;
// Add a '0' at the start to index into the struct.
auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
@@ -606,7 +607,7 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
loc, indexType, IntegerAttr::get(indexType, offset));
assert(indices.size() == strides.size() &&
"must provide indices for all dimensions");
- for (auto index : enumerate(indices)) {
+ for (auto index : llvm::enumerate(indices)) {
Value strideVal = builder.create<spirv::ConstantOp>(
loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
Value update =
diff --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/OpClass.cpp
index bfdcbdc344a3..43bbe2420a9a 100644
--- a/mlir/lib/TableGen/OpClass.cpp
+++ b/mlir/lib/TableGen/OpClass.cpp
@@ -119,6 +119,27 @@ void tblgen::OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
os << "}";
}
+//===----------------------------------------------------------------------===//
+// OpConstructor definitions
+//===----------------------------------------------------------------------===//
+
+void mlir::tblgen::OpConstructor::addMemberInitializer(StringRef name,
+ StringRef value) {
+ memberInitializers.append(std::string(llvm::formatv(
+ "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
+}
+
+void mlir::tblgen::OpConstructor::writeDefTo(raw_ostream &os,
+ StringRef namePrefix) const {
+ if (isDeclOnly)
+ return;
+
+ methodSignature.writeDefTo(os, namePrefix);
+ os << " " << memberInitializers << " {\n";
+ methodBody.writeTo(os);
+ os << "}";
+}
+
//===----------------------------------------------------------------------===//
// Class definitions
//===----------------------------------------------------------------------===//
@@ -133,10 +154,11 @@ tblgen::OpMethod &tblgen::Class::newMethod(StringRef retType, StringRef name,
return methods.back();
}
-tblgen::OpMethod &tblgen::Class::newConstructor(StringRef params,
- bool declOnly) {
- return newMethod("", getClassName(), params, OpMethod::MP_Constructor,
- declOnly);
+tblgen::OpConstructor &tblgen::Class::newConstructor(StringRef params,
+ bool declOnly) {
+ constructors.emplace_back("", getClassName(), params,
+ OpMethod::MP_Constructor, declOnly);
+ return constructors.back();
}
void tblgen::Class::newField(StringRef type, StringRef name,
@@ -152,7 +174,8 @@ void tblgen::Class::writeDeclTo(raw_ostream &os) const {
bool hasPrivateMethod = false;
os << "class " << className << " {\n";
os << "public:\n";
- for (const auto &method : methods) {
+ for (const auto &method :
+ llvm::concat<const OpMethod>(constructors, methods)) {
if (!method.isPrivate()) {
method.writeDeclTo(os);
os << '\n';
@@ -163,7 +186,8 @@ void tblgen::Class::writeDeclTo(raw_ostream &os) const {
os << '\n';
os << "private:\n";
if (hasPrivateMethod) {
- for (const auto &method : methods) {
+ for (const auto &method :
+ llvm::concat<const OpMethod>(constructors, methods)) {
if (method.isPrivate()) {
method.writeDeclTo(os);
os << '\n';
@@ -177,7 +201,8 @@ void tblgen::Class::writeDeclTo(raw_ostream &os) const {
}
void tblgen::Class::writeDefTo(raw_ostream &os) const {
- for (const auto &method : methods) {
+ for (const auto &method :
+ llvm::concat<const OpMethod>(constructors, methods)) {
method.writeDefTo(os, className);
os << "\n\n";
}
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 2f77184980e2..f575fedc1f24 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -59,6 +59,10 @@ std::string tblgen::Operator::getOperationName() const {
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
}
+std::string tblgen::Operator::getAdaptorName() const {
+ return std::string(llvm::formatv("{0}OperandAdaptor", getCppClassName()));
+}
+
StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); }
StringRef tblgen::Operator::getCppClassName() const { return cppClassName; }
diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index 565f1921125a..a101103b08fc 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -49,14 +49,14 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: class AOpOperandAdaptor {
// CHECK: public:
-// CHECK: AOpOperandAdaptor(ArrayRef<Value> values
-// CHECK: ArrayRef<Value> getODSOperands(unsigned index);
+// CHECK: AOpOperandAdaptor(ValueRange values
+// CHECK: ValueRange getODSOperands(unsigned index);
// CHECK: Value a();
-// CHECK: ArrayRef<Value> b();
+// CHECK: ValueRange b();
// CHECK: IntegerAttr attr1();
// CHECL: FloatAttr attr2();
// CHECK: private:
-// CHECK: ArrayRef<Value> odsOperands;
+// CHECK: ValueRange odsOperands;
// CHECK: };
// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNRegions<1>::Impl, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove
@@ -106,12 +106,12 @@ def NS_AttrSizedOperandOp : NS_Op<"attr_sized_operands",
}
// CHECK-LABEL: AttrSizedOperandOpOperandAdaptor(
-// CHECK-SAME: ArrayRef<Value> values
+// CHECK-SAME: ValueRange values
// CHECK-SAME: DictionaryAttr attrs
-// CHECK: ArrayRef<Value> a();
-// CHECK: ArrayRef<Value> b();
+// CHECK: ValueRange a();
+// CHECK: ValueRange b();
// CHECK: Value c();
-// CHECK: ArrayRef<Value> d();
+// CHECK: ValueRange d();
// CHECK: DenseIntElementsAttr operand_segment_sizes();
// Check op trait for different number of operands
diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td
index 5f0bfae92812..a9b61c179be0 100644
--- a/mlir/test/mlir-tblgen/op-operand.td
+++ b/mlir/test/mlir-tblgen/op-operand.td
@@ -15,7 +15,7 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
// CHECK-LABEL: OpA definitions
// CHECK: OpAOperandAdaptor::OpAOperandAdaptor
-// CHECK-NEXT: odsOperands = values
+// CHECK-SAME: odsOperands(values), odsAttrs(attrs)
// CHECK: void OpA::build
// CHECK: Value input
@@ -39,13 +39,13 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
}
-// CHECK-LABEL: ArrayRef<Value> OpDOperandAdaptor::input1
+// CHECK-LABEL: ValueRange OpDOperandAdaptor::input1
// CHECK-NEXT: return getODSOperands(0);
// CHECK-LABEL: Value OpDOperandAdaptor::input2
// CHECK-NEXT: return *getODSOperands(1).begin();
-// CHECK-LABEL: ArrayRef<Value> OpDOperandAdaptor::input3
+// CHECK-LABEL: ValueRange OpDOperandAdaptor::input3
// CHECK-NEXT: return getODSOperands(2);
// CHECK-LABEL: Operation::operand_range OpD::input1
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 0b55825d1a46..7b0cd9d7a482 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1890,27 +1890,38 @@ class OpOperandAdaptorEmitter {
private:
explicit OpOperandAdaptorEmitter(const Operator &op);
- Class adapterClass;
+ Class adaptor;
};
} // end namespace
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
- : adapterClass(op.getCppClassName().str() + "OperandAdaptor") {
- adapterClass.newField("ArrayRef<Value>", "odsOperands");
- adapterClass.newField("DictionaryAttr", "odsAttrs");
+ : adaptor(op.getAdaptorName()) {
+ adaptor.newField("ValueRange", "odsOperands");
+ adaptor.newField("DictionaryAttr", "odsAttrs");
const auto *attrSizedOperands =
op.getTrait("OpTrait::AttrSizedOperandSegments");
- auto &constructor = adapterClass.newConstructor(
- attrSizedOperands
- ? "ArrayRef<Value> values, DictionaryAttr attrs"
- : "ArrayRef<Value> values, DictionaryAttr attrs = nullptr");
- constructor.body() << " odsOperands = values;\n";
- constructor.body() << " odsAttrs = attrs;\n";
+ {
+ auto &constructor = adaptor.newConstructor(
+ attrSizedOperands
+ ? "ValueRange values, DictionaryAttr attrs"
+ : "ValueRange values, DictionaryAttr attrs = nullptr");
+ constructor.addMemberInitializer("odsOperands", "values");
+ constructor.addMemberInitializer("odsAttrs", "attrs");
+ }
+
+ {
+ auto &constructor = adaptor.newConstructor(
+ llvm::formatv("{0}& op", op.getCppClassName()).str());
+ constructor.addMemberInitializer("odsOperands",
+ "op.getOperation()->getOperands()");
+ constructor.addMemberInitializer("odsAttrs",
+ "op.getOperation()->getAttrDictionary()");
+ }
std::string sizeAttrInit =
formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
- generateNamedOperandGetters(op, adapterClass, sizeAttrInit,
- /*rangeType=*/"ArrayRef<Value>",
+ generateNamedOperandGetters(op, adaptor, sizeAttrInit,
+ /*rangeType=*/"ValueRange",
/*rangeBeginCall=*/"odsOperands.begin()",
/*rangeSizeCall=*/"odsOperands.size()",
/*getOperandCallPattern=*/"odsOperands[{0}]");
@@ -1919,7 +1930,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
fctx.withBuilder("mlir::Builder(odsAttrs.getContext())");
auto emitAttr = [&](StringRef name, Attribute attr) {
- auto &body = adapterClass.newMethod(attr.getStorageType(), name).body();
+ auto &body = adaptor.newMethod(attr.getStorageType(), name).body();
body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
<< "\n " << attr.getStorageType() << " attr = "
<< "odsAttrs.get(\"" << name << "\").";
@@ -1949,11 +1960,11 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
}
void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
- OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os);
+ OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
}
void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
- OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os);
+ OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
}
// Emits the opcode enum and op classes.
More information about the Mlir-commits
mailing list