[Mlir-commits] [mlir] 9ea6b30 - Update ODS variadic segments "magic" attributes to use native Properties
Mehdi Amini
llvmlistbot at llvm.org
Mon Jul 24 18:17:23 PDT 2023
Author: Mehdi Amini
Date: 2023-07-24T18:16:58-07:00
New Revision: 9ea6b30ac20f8223fb6aeae853e5c73691850a8d
URL: https://github.com/llvm/llvm-project/commit/9ea6b30ac20f8223fb6aeae853e5c73691850a8d
DIFF: https://github.com/llvm/llvm-project/commit/9ea6b30ac20f8223fb6aeae853e5c73691850a8d.diff
LOG: Update ODS variadic segments "magic" attributes to use native Properties
The operand_segment_sizes and result_segment_sizes Attributes are now inlined
in the operation as native properties. We continue to support building an
Attribute on the fly for `getAttr("operand_segment_sizes")` and setting the
property from an attribute with `setAttr("operand_segment_sizes", attr)`.
A new bytecode version is introduced to support backward compatibility and
back-deployment.
Differential Revision: https://reviews.llvm.org/D155919
Added:
Modified:
mlir/include/mlir/Bytecode/BytecodeImplementation.h
mlir/include/mlir/Bytecode/Encoding.h
mlir/include/mlir/IR/ODSSupport.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/TableGen/Property.h
mlir/lib/Bytecode/Reader/BytecodeReader.cpp
mlir/lib/Bytecode/Writer/IRNumbering.cpp
mlir/lib/Bytecode/Writer/IRNumbering.h
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/IR/ODSSupport.cpp
mlir/lib/TableGen/Property.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/IR/traits.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/python/dialects/linalg/ops.py
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
mlir/unittests/IR/AdaptorTest.cpp
mlir/unittests/IR/OpPropertiesTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 4e74c124adcde3..9c9aa7a4fc0ed1 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -20,6 +20,7 @@
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
namespace mlir {
@@ -39,6 +40,9 @@ class DialectBytecodeReader {
/// Emit an error to the reader.
virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
+ /// Return the bytecode version being read.
+ virtual uint64_t getBytecodeVersion() const = 0;
+
/// Read out a list of elements, invoking the provided callback for each
/// element. The callback function may be in any of the following forms:
/// * LogicalResult(T &)
@@ -148,6 +152,76 @@ class DialectBytecodeReader {
[this](int64_t &value) { return readSignedVarInt(value); });
}
+ /// Parse a variable length encoded integer whose low bit is used to encode an
+ /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
+ LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag) {
+ if (failed(readVarInt(result)))
+ return failure();
+ flag = result & 1;
+ result >>= 1;
+ return success();
+ }
+
+ /// Read a "small" sparse array of integer <= 32 bits elements, where
+ /// index/value pairs can be compressed when the array is small.
+ /// Note that only some positions of the array may be written and the ones
+ /// not stored in the bytecode are left untouched.
+ /// If the provided array is too small for the stored indices, an error
+ /// will be returned.
+ /// Each decoded value is stored into an element of type `T`, so values that
+ /// do not fit in `T` are narrowed by the assignment.
+ template <typename T>
+ LogicalResult readSparseArray(MutableArrayRef<T> array) {
+ static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
+ static_assert(std::is_integral<T>::value, "expects integer");
+ // The leading varint packs a count with a low-bit flag selecting the
+ // encoding. For the dense encoding this count is the number of stored
+ // elements (the whole array), not a count of non-zero elements.
+ uint64_t nonZeroesCount;
+ bool useSparseEncoding;
+ if (failed(readVarIntWithFlag(nonZeroesCount, useSparseEncoding)))
+ return failure();
+ if (nonZeroesCount == 0)
+ return success();
+ if (!useSparseEncoding) {
+ // This is a simple dense array: `nonZeroesCount` consecutive values
+ // starting at index 0.
+ if (nonZeroesCount > array.size()) {
+ emitError("trying to read an array of ")
+ << nonZeroesCount << " but only " << array.size()
+ << " storage available.";
+ return failure();
+ }
+ for (int64_t index : llvm::seq<int64_t>(0, nonZeroesCount)) {
+ uint64_t value;
+ if (failed(readVarInt(value)))
+ return failure();
+ array[index] = value;
+ }
+ return success();
+ }
+ // Read sparse encoding: `nonZeroesCount` varints, each packing
+ // `(value << indexBitSize) | index`.
+ // This is the number of bits used for packing the index with the value.
+ uint64_t indexBitSize;
+ if (failed(readVarInt(indexBitSize)))
+ return failure();
+ // Indices are limited to 8 bits, so only positions 0..255 can be encoded
+ // sparsely; anything wider is rejected as malformed input.
+ constexpr uint64_t maxIndexBitSize = 8;
+ if (indexBitSize > maxIndexBitSize) {
+ emitError("reading sparse array with indexing above 8 bits: ")
+ << indexBitSize;
+ return failure();
+ }
+ for (uint32_t count : llvm::seq<uint32_t>(0, nonZeroesCount)) {
+ (void)count;
+ uint64_t indexValuePair;
+ if (failed(readVarInt(indexValuePair)))
+ return failure();
+ // Low `indexBitSize` bits hold the index; the remaining bits the value.
+ uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize));
+ uint64_t value = indexValuePair >> indexBitSize;
+ if (index >= array.size()) {
+ emitError("reading a sparse array found index ")
+ << index << " but only " << array.size() << " storage available.";
+ return failure();
+ }
+ array[index] = value;
+ }
+ return success();
+ }
+
/// Read an APInt that is known to have been encoded with the given width.
virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0;
@@ -230,6 +304,55 @@ class DialectBytecodeWriter {
writeList(value, [this](int64_t value) { writeSignedVarInt(value); });
}
+ /// Write a VarInt and a flag packed together.
+ void writeVarIntWithFlag(uint64_t value, bool flag) {
+ writeVarInt((value << 1) | (flag ? 1 : 0));
+ }
+
+ /// Write out a "small" sparse array of integer <= 32 bits elements, where
+ /// index/value pairs can be compressed when the array is small. This method
+ /// will scan the array multiple times and should not be used for large
+ /// arrays. Zero elements are the ones omitted by the sparse encoding. We
+ /// assume here that the array size fits in a 32 bits integer. The output is
+ /// meant to be decoded by `DialectBytecodeReader::readSparseArray`.
+ template <typename T>
+ void writeSparseArray(ArrayRef<T> array) {
+ static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
+ static_assert(std::is_integral<T>::value, "expects integer");
+ uint32_t size = array.size();
+ uint32_t nonZeroesCount = 0, lastIndex = 0;
+ for (uint32_t index : llvm::seq<uint32_t>(0, size)) {
+ if (!array[index])
+ continue;
+ nonZeroesCount++;
+ lastIndex = index;
+ }
+ // If the last index needs more than the 8 index bits accepted by the
+ // reader (i.e. > 255), or the array isn't at least 50% sparse, emit dense.
+ if (lastIndex > 255 || nonZeroesCount > size / 2) {
+ // Emit the array size and a flag which indicates whether it is sparse.
+ writeVarIntWithFlag(size, false);
+ for (const T &elt : array)
+ writeVarInt(elt);
+ return;
+ }
+ // Emit sparse: first the number of elements we'll write and a flag
+ // indicating it is a sparse encoding.
+ writeVarIntWithFlag(nonZeroesCount, true);
+ if (nonZeroesCount == 0)
+ return;
+ // This is the number of bits used for packing the index with the value.
+ int indexBitSize = llvm::Log2_32_Ceil(lastIndex + 1);
+ writeVarInt(indexBitSize);
+ for (uint32_t index : llvm::seq<uint32_t>(0, lastIndex + 1)) {
+ T value = array[index];
+ if (!value)
+ continue;
+ // Widen before shifting: shifting the promoted `T` may overflow (UB).
+ writeVarInt((static_cast<uint64_t>(value) << indexBitSize) | index);
+ }
+ }
+
/// Write an APInt to the bytecode stream whose bitwidth will be known
/// externally at read time. This method is useful for encoding APInt values
/// when the width is known via external means, such as via a type. This
diff --git a/mlir/include/mlir/Bytecode/Encoding.h b/mlir/include/mlir/Bytecode/Encoding.h
index 21edc1e3d9857a..ac3826929ce7c1 100644
--- a/mlir/include/mlir/Bytecode/Encoding.h
+++ b/mlir/include/mlir/Bytecode/Encoding.h
@@ -45,8 +45,12 @@ enum BytecodeVersion {
/// with the discardable attributes.
kNativePropertiesEncoding = 5,
+ /// ODS emits operand/result segment_size as native properties instead of
+ /// an attribute.
+ kNativePropertiesODSSegmentSize = 6,
+
/// The current bytecode version.
- kVersion = 5,
+ kVersion = 6,
/// An arbitrary value used to fill alignment padding.
kAlignmentByte = 0xCB,
diff --git a/mlir/include/mlir/IR/ODSSupport.h b/mlir/include/mlir/IR/ODSSupport.h
index 1d3cbbd6900341..687f764ae95fd9 100644
--- a/mlir/include/mlir/IR/ODSSupport.h
+++ b/mlir/include/mlir/IR/ODSSupport.h
@@ -37,6 +37,13 @@ Attribute convertToAttribute(MLIRContext *ctx, int64_t storage);
LogicalResult convertFromAttribute(MutableArrayRef<int64_t> storage,
Attribute attr, InFlightDiagnostic *diag);
+/// Convert a DenseI32ArrayAttr to the provided storage. It is expected that the
+/// storage has the same size as the array. An error is returned if the
+/// attribute isn't a DenseI32ArrayAttr or it does not have the same size. If
+/// the optional diagnostic is provided an error message is also emitted.
+LogicalResult convertFromAttribute(MutableArrayRef<int32_t> storage,
+ Attribute attr, InFlightDiagnostic *diag);
+
/// Convert the provided ArrayRef<int64_t> to a DenseI64ArrayAttr attribute.
Attribute convertToAttribute(MLIRContext *ctx, ArrayRef<int64_t> storage);
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 940588b7c0f9f2..274a531f4061e6 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1241,6 +1241,7 @@ class ArrayProperty<string storageTypeParam = "", int n, string desc = ""> :
let interfaceType = "::llvm::ArrayRef<" # storageTypeParam # ">";
let convertFromStorage = "$_storage";
let assignToStorage = "::llvm::copy($_value, $_storage)";
+ let hashProperty = "llvm::hash_combine_range(std::begin($_storage), std::end($_storage));";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 221c607c15f4c9..d42bffaf32b03a 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -20,6 +20,7 @@
#define MLIR_IR_OPDEFINITION_H
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/ODSSupport.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 303e0303a87fa0..f3a79eb52f8ec0 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -555,7 +555,8 @@ class RegisteredOperationName : public OperationName {
StringRef name) final {
if constexpr (hasProperties) {
auto concreteOp = cast<ConcreteOp>(op);
- return ConcreteOp::getInherentAttr(concreteOp.getProperties(), name);
+ return ConcreteOp::getInherentAttr(concreteOp.getContext(),
+ concreteOp.getProperties(), name);
}
// If the op does not have support for properties, we dispatch back to the
// dictionnary of discardable attributes for now.
@@ -575,7 +576,8 @@ class RegisteredOperationName : public OperationName {
void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final {
if constexpr (hasProperties) {
auto concreteOp = cast<ConcreteOp>(op);
- ConcreteOp::populateInherentAttrs(concreteOp.getProperties(), attrs);
+ ConcreteOp::populateInherentAttrs(concreteOp.getContext(),
+ concreteOp.getProperties(), attrs);
}
}
LogicalResult
diff --git a/mlir/include/mlir/TableGen/Property.h b/mlir/include/mlir/TableGen/Property.h
index 597543d6f7261d..d0d6f4940c7c04 100644
--- a/mlir/include/mlir/TableGen/Property.h
+++ b/mlir/include/mlir/TableGen/Property.h
@@ -35,51 +35,76 @@ class Property {
public:
explicit Property(const llvm::Record *record);
explicit Property(const llvm::DefInit *init);
+ Property(StringRef storageType, StringRef interfaceType,
+ StringRef convertFromStorageCall, StringRef assignToStorageCall,
+ StringRef convertToAttributeCall, StringRef convertFromAttributeCall,
+ StringRef readFromMlirBytecodeCall,
+ StringRef writeToMlirBytecodeCall, StringRef hashPropertyCall,
+ StringRef defaultValue);
// Returns the storage type.
- StringRef getStorageType() const;
+ StringRef getStorageType() const { return storageType; }
// Returns the interface type for this property.
- StringRef getInterfaceType() const;
+ StringRef getInterfaceType() const { return interfaceType; }
// Returns the template getter method call which reads this property's
// storage and returns the value as of the desired return type.
- StringRef getConvertFromStorageCall() const;
+ StringRef getConvertFromStorageCall() const { return convertFromStorageCall; }
// Returns the template setter method call which reads this property's
// in the provided interface type and assign it to the storage.
- StringRef getAssignToStorageCall() const;
+ StringRef getAssignToStorageCall() const { return assignToStorageCall; }
// Returns the conversion method call which reads this property's
// in the storage type and builds an attribute.
- StringRef getConvertToAttributeCall() const;
+ StringRef getConvertToAttributeCall() const { return convertToAttributeCall; }
// Returns the setter method call which reads this property's
// in the provided interface type and assign it to the storage.
- StringRef getConvertFromAttributeCall() const;
+ StringRef getConvertFromAttributeCall() const {
+ return convertFromAttributeCall;
+ }
// Returns the method call which reads this property from
// bytecode and assign it to the storage.
- StringRef getReadFromMlirBytecodeCall() const;
+ StringRef getReadFromMlirBytecodeCall() const {
+ return readFromMlirBytecodeCall;
+ }
// Returns the method call which write this property's
// to the the bytecode.
- StringRef getWriteToMlirBytecodeCall() const;
+ StringRef getWriteToMlirBytecodeCall() const {
+ return writeToMlirBytecodeCall;
+ }
// Returns the code to compute the hash for this property.
- StringRef getHashPropertyCall() const;
+ StringRef getHashPropertyCall() const { return hashPropertyCall; }
// Returns whether this Property has a default value.
- bool hasDefaultValue() const;
+ bool hasDefaultValue() const { return !defaultValue.empty(); }
+
// Returns the default value for this Property.
- StringRef getDefaultValue() const;
+ StringRef getDefaultValue() const { return defaultValue; }
// Returns the TableGen definition this Property was constructed from.
- const llvm::Record &getDef() const;
+ const llvm::Record &getDef() const { return *def; }
private:
// The TableGen definition of this constraint.
const llvm::Record *def;
+
+ // Elements describing a Property, in general fetched from the record.
+ StringRef storageType;
+ StringRef interfaceType;
+ StringRef convertFromStorageCall;
+ StringRef assignToStorageCall;
+ StringRef convertToAttributeCall;
+ StringRef convertFromAttributeCall;
+ StringRef readFromMlirBytecodeCall;
+ StringRef writeToMlirBytecodeCall;
+ StringRef hashPropertyCall;
+ StringRef defaultValue;
};
// A struct wrapping an op property and its name together
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 8269546d983655..0639baf10b0bc0 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -796,9 +796,10 @@ class AttrTypeReader {
public:
AttrTypeReader(StringSectionReader &stringReader,
- ResourceSectionReader &resourceReader, Location fileLoc)
+ ResourceSectionReader &resourceReader, Location fileLoc,
+ uint64_t &bytecodeVersion)
: stringReader(stringReader), resourceReader(resourceReader),
- fileLoc(fileLoc) {}
+ fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {}
/// Initialize the attribute and type information within the reader.
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
@@ -883,23 +884,30 @@ class AttrTypeReader {
/// A location used for error emission.
Location fileLoc;
+
+ /// Current bytecode version being used.
+ uint64_t &bytecodeVersion;
};
class DialectReader : public DialectBytecodeReader {
public:
DialectReader(AttrTypeReader &attrTypeReader,
StringSectionReader &stringReader,
- ResourceSectionReader &resourceReader, EncodingReader &reader)
+ ResourceSectionReader &resourceReader, EncodingReader &reader,
+ uint64_t &bytecodeVersion)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
- resourceReader(resourceReader), reader(reader) {}
+ resourceReader(resourceReader), reader(reader),
+ bytecodeVersion(bytecodeVersion) {}
InFlightDiagnostic emitError(const Twine &msg) override {
return reader.emitError(msg);
}
+ uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
+
DialectReader withEncodingReader(EncodingReader &encReader) {
return DialectReader(attrTypeReader, stringReader, resourceReader,
- encReader);
+ encReader, bytecodeVersion);
}
Location getLoc() const { return reader.getLoc(); }
@@ -1003,6 +1011,7 @@ class DialectReader : public DialectBytecodeReader {
StringSectionReader &stringReader;
ResourceSectionReader &resourceReader;
EncodingReader &reader;
+ uint64_t &bytecodeVersion;
};
/// Wraps the properties section and handles reading properties out of it.
@@ -1207,7 +1216,8 @@ template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
StringRef entryType) {
- DialectReader dialectReader(*this, stringReader, resourceReader, reader);
+ DialectReader dialectReader(*this, stringReader, resourceReader, reader,
+ bytecodeVersion);
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
// Ensure that the dialect implements the bytecode interface.
@@ -1252,7 +1262,7 @@ class mlir::BytecodeReader::Impl {
llvm::MemoryBufferRef buffer,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
: config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
- attrTypeReader(stringReader, resourceReader, fileLoc),
+ attrTypeReader(stringReader, resourceReader, fileLoc, version),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
@@ -1782,7 +1792,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
if (!opName->opName) {
// Load the dialect and its version.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
- reader);
+ reader, version);
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
// If the opName is empty, this is because we use to accept names such as
@@ -1825,7 +1835,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection(
// Initialize the resource reader with the resource sections.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
- reader);
+ reader, version);
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
*resourceData, *resourceOffsetData,
dialectReader, bufferOwnerRef);
@@ -2186,7 +2196,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
// interface and control the serialization.
if (wasRegistered) {
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
- reader);
+ reader, version);
if (failed(
propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
return failure();
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 2547d815c12b17..284b3c02f1f2ce 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -48,7 +48,7 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
void writeOwnedBool(bool value) override {}
int64_t getBytecodeVersion() const override {
- llvm_unreachable("unexpected querying of version in IRNumbering");
+ return state.getDesiredBytecodeVersion();
}
/// The parent numbering state that is populated by this writer.
@@ -391,6 +391,10 @@ void IRNumberingState::number(Dialect *dialect,
}
}
+int64_t IRNumberingState::getDesiredBytecodeVersion() const {
+ return config.getDesiredBytecodeVersion();
+}
+
namespace {
/// A dummy resource builder used to number dialect resources.
struct NumberingResourceBuilder : public AsmResourceBuilder {
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
index c10e09ade6e486..ca30078f3468f4 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -186,6 +186,9 @@ class IRNumberingState {
return blockOperationCounts[block];
}
+ /// Get the set desired bytecode version to emit.
+ int64_t getDesiredBytecodeVersion() const;
+
private:
/// This class is used to provide a fake dialect writer for numbering nested
/// attributes and types.
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index b140a463982f9b..5231fe5f94d8a0 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
@@ -369,10 +370,15 @@ static LogicalResult inferOperationTypes(OperationState &state) {
if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) {
auto prop = std::make_unique<char[]>(info->getOpPropertyByteSize());
properties = OpaqueProperties(prop.get());
+ InFlightDiagnostic diag = emitError(state.location)
+ << " failed properties conversion while building "
+ << state.name.getStringRef() << " with `"
+ << attributes << "`: ";
if (failed(info->setOpPropertiesFromAttribute(state.name, properties,
- attributes, nullptr))) {
+ attributes, &diag))) {
return failure();
}
+ diag.abandon();
if (succeeded(inferInterface->inferReturnTypes(
context, state.location, state.operands, attributes, properties,
diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp
index ffb84f0f718d38..f67e7dbf38592e 100644
--- a/mlir/lib/IR/ODSSupport.cpp
+++ b/mlir/lib/IR/ODSSupport.cpp
@@ -33,24 +33,40 @@ LogicalResult mlir::convertFromAttribute(int64_t &storage,
Attribute mlir::convertToAttribute(MLIRContext *ctx, int64_t storage) {
return IntegerAttr::get(IntegerType::get(ctx, 64), storage);
}
-LogicalResult mlir::convertFromAttribute(MutableArrayRef<int64_t> storage,
- ::mlir::Attribute attr,
- ::mlir::InFlightDiagnostic *diag) {
- auto valueAttr = dyn_cast<DenseI64ArrayAttr>(attr);
+
+/// Copy the elements of a dense array attribute of type `DenseArrayTy` into
+/// `storage`. Fails (and appends a message to `diag` when provided) if `attr`
+/// is not a `DenseArrayTy` or its size differs from `storage.size()`.
+/// `denseArrayTyStr` is the type name used in the diagnostic.
+/// File-local helper, hence `static`.
+template <typename DenseArrayTy, typename T>
+static LogicalResult
+convertDenseArrayFromAttr(MutableArrayRef<T> storage, ::mlir::Attribute attr,
+ ::mlir::InFlightDiagnostic *diag,
+ StringRef denseArrayTyStr) {
+ auto valueAttr = dyn_cast<DenseArrayTy>(attr);
if (!valueAttr) {
if (diag)
- *diag << "expected DenseI64ArrayAttr for key `value`";
+ *diag << "expected " << denseArrayTyStr << " for key `value`";
return failure();
}
if (valueAttr.size() != static_cast<int64_t>(storage.size())) {
if (diag)
- *diag << "Size mismatch in attribute conversion: " << valueAttr.size()
+ *diag << "size mismatch in attribute conversion: " << valueAttr.size()
<< " vs " << storage.size();
return failure();
}
llvm::copy(valueAttr.asArrayRef(), storage.begin());
return success();
}
+LogicalResult mlir::convertFromAttribute(MutableArrayRef<int64_t> storage,
+ ::mlir::Attribute attr,
+ ::mlir::InFlightDiagnostic *diag) {
+ return convertDenseArrayFromAttr<DenseI64ArrayAttr>(storage, attr, diag,
+ "DenseI64ArrayAttr");
+}
+LogicalResult mlir::convertFromAttribute(MutableArrayRef<int32_t> storage,
+ Attribute attr,
+ InFlightDiagnostic *diag) {
+ return convertDenseArrayFromAttr<DenseI32ArrayAttr>(storage, attr, diag,
+ "DenseI32ArrayAttr");
+}
+
Attribute mlir::convertToAttribute(MLIRContext *ctx,
ArrayRef<int64_t> storage) {
return DenseI64ArrayAttr::get(ctx, storage);
diff --git a/mlir/lib/TableGen/Property.cpp b/mlir/lib/TableGen/Property.cpp
index b0bea43dd5fda2..e61d2fd2480fd5 100644
--- a/mlir/lib/TableGen/Property.cpp
+++ b/mlir/lib/TableGen/Property.cpp
@@ -32,65 +32,40 @@ static StringRef getValueAsString(const Init *init) {
return {};
}
-Property::Property(const Record *record) : def(record) {
- assert((record->isSubClassOf("Property") || record->isSubClassOf("Attr")) &&
+Property::Property(const Record *def)
+ : Property(getValueAsString(def->getValueInit("storageType")),
+ getValueAsString(def->getValueInit("interfaceType")),
+ getValueAsString(def->getValueInit("convertFromStorage")),
+ getValueAsString(def->getValueInit("assignToStorage")),
+ getValueAsString(def->getValueInit("convertToAttribute")),
+ getValueAsString(def->getValueInit("convertFromAttribute")),
+ getValueAsString(def->getValueInit("readFromMlirBytecode")),
+ getValueAsString(def->getValueInit("writeToMlirBytecode")),
+ getValueAsString(def->getValueInit("hashProperty")),
+ getValueAsString(def->getValueInit("defaultValue"))) {
+ this->def = def;
+ assert((def->isSubClassOf("Property") || def->isSubClassOf("Attr")) &&
"must be subclass of TableGen 'Property' class");
}
Property::Property(const DefInit *init) : Property(init->getDef()) {}
-StringRef Property::getStorageType() const {
- const auto *init = def->getValueInit("storageType");
- auto type = getValueAsString(init);
- if (type.empty())
- return "Property";
- return type;
+Property::Property(StringRef storageType, StringRef interfaceType,
+ StringRef convertFromStorageCall,
+ StringRef assignToStorageCall,
+ StringRef convertToAttributeCall,
+ StringRef convertFromAttributeCall,
+ StringRef readFromMlirBytecodeCall,
+ StringRef writeToMlirBytecodeCall,
+ StringRef hashPropertyCall, StringRef defaultValue)
+ : def(nullptr), storageType(storageType), interfaceType(interfaceType),
+ convertFromStorageCall(convertFromStorageCall),
+ assignToStorageCall(assignToStorageCall),
+ convertToAttributeCall(convertToAttributeCall),
+ convertFromAttributeCall(convertFromAttributeCall),
+ readFromMlirBytecodeCall(readFromMlirBytecodeCall),
+ writeToMlirBytecodeCall(writeToMlirBytecodeCall),
+ hashPropertyCall(hashPropertyCall), defaultValue(defaultValue) {
+ if (this->storageType.empty()) // Parameters shadow the members here.
+ this->storageType = "Property";
}
-
-StringRef Property::getInterfaceType() const {
- const auto *init = def->getValueInit("interfaceType");
- return getValueAsString(init);
-}
-
-StringRef Property::getConvertFromStorageCall() const {
- const auto *init = def->getValueInit("convertFromStorage");
- return getValueAsString(init);
-}
-
-StringRef Property::getAssignToStorageCall() const {
- const auto *init = def->getValueInit("assignToStorage");
- return getValueAsString(init);
-}
-
-StringRef Property::getConvertToAttributeCall() const {
- const auto *init = def->getValueInit("convertToAttribute");
- return getValueAsString(init);
-}
-
-StringRef Property::getConvertFromAttributeCall() const {
- const auto *init = def->getValueInit("convertFromAttribute");
- return getValueAsString(init);
-}
-
-StringRef Property::getReadFromMlirBytecodeCall() const {
- const auto *init = def->getValueInit("readFromMlirBytecode");
- return getValueAsString(init);
-}
-
-StringRef Property::getWriteToMlirBytecodeCall() const {
- const auto *init = def->getValueInit("writeToMlirBytecode");
- return getValueAsString(init);
-}
-
-StringRef Property::getHashPropertyCall() const {
- return getValueAsString(def->getValueInit("hashProperty"));
-}
-
-bool Property::hasDefaultValue() const { return !getDefaultValue().empty(); }
-
-StringRef Property::getDefaultValue() const {
- const auto *init = def->getValueInit("defaultValue");
- return getValueAsString(init);
-}
-
-const llvm::Record &Property::getDef() const { return *def; }
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 09bbc5a4739657..14141c4c243ab3 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -887,7 +887,7 @@ func.func @switch_wrong_number_of_weights(%arg0 : i32) {
func.func @switch_case_type_mismatch(%arg0 : i64) {
// expected-error at below {{expects case value type to match condition value type}}
- "llvm.switch"(%arg0)[^bb1, ^bb2] <{case_operand_segments = array<i32: 0>, case_values = dense<42> : vector<1xi32>, operand_segment_sizes = array<i32: 1, 0, 0>}> : (i64) -> ()
+ "llvm.switch"(%arg0)[^bb1, ^bb2] <{case_operand_segments = array<i32: 0>, case_values = dense<42> : vector<1xi32>, odsOperandSegmentSizes = array<i32: 1, 0, 0>}> : (i64) -> ()
^bb1: // pred: ^bb0
llvm.return
^bb2: // pred: ^bb0
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 307918defab0e1..7d922ecf67de5d 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -383,21 +383,21 @@ func.func private @foo()
// -----
func.func @failedMissingOperandSizeAttr(%arg: i32) {
- // expected-error @+1 {{requires dense i32 array attribute 'operand_segment_sizes'}}
+ // expected-error @+1 {{op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes'}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) : (i32, i32, i32, i32) -> ()
}
// -----
func.func @failedOperandSizeAttrWrongType(%arg: i32) {
- // expected-error @+1 {{attribute 'operand_segment_sizes' failed to satisfy constraint: i32 dense array attribute}}
+ // expected-error @+1 {{op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes'}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = 10} : (i32, i32, i32, i32) -> ()
}
// -----
func.func @failedOperandSizeAttrWrongElementType(%arg: i32) {
- // expected-error @+1 {{attribute 'operand_segment_sizes' failed to satisfy constraint: i32 dense array attribute}}
+ // expected-error @+1 {{op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes'}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = array<i64: 1, 1, 1, 1>} : (i32, i32, i32, i32) -> ()
}
@@ -418,7 +418,7 @@ func.func @failedOperandSizeAttrWrongTotalSize(%arg: i32) {
// -----
func.func @failedOperandSizeAttrWrongCount(%arg: i32) {
- // expected-error @+1 {{'operand_segment_sizes' attribute for specifying operand segments must have 4 elements}}
+ // expected-error @+1 {{test.attr_sized_operands' op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = array<i32: 2, 1, 1>} : (i32, i32, i32, i32) -> ()
}
@@ -433,14 +433,14 @@ func.func @succeededOperandSizeAttr(%arg: i32) {
// -----
func.func @failedMissingResultSizeAttr() {
- // expected-error @+1 {{requires dense i32 array attribute 'result_segment_sizes'}}
+ // expected-error @+1 {{op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}}
%0:4 = "test.attr_sized_results"() : () -> (i32, i32, i32, i32)
}
// -----
func.func @failedResultSizeAttrWrongType() {
- // expected-error @+1 {{requires dense i32 array attribute 'result_segment_sizes'}}
+ // expected-error @+1 {{ op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}}
%0:4 = "test.attr_sized_results"() {result_segment_sizes = 10} : () -> (i32, i32, i32, i32)
}
@@ -448,7 +448,7 @@ func.func @failedResultSizeAttrWrongType() {
// -----
func.func @failedResultSizeAttrWrongElementType() {
- // expected-error @+1 {{requires dense i32 array attribute 'result_segment_sizes'}}
+ // expected-error @+1 {{ op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}}
%0:4 = "test.attr_sized_results"() {result_segment_sizes = array<i64: 1, 1, 1, 1>} : () -> (i32, i32, i32, i32)
}
@@ -469,7 +469,7 @@ func.func @failedResultSizeAttrWrongTotalSize() {
// -----
func.func @failedResultSizeAttrWrongCount() {
- // expected-error @+1 {{'result_segment_sizes' attribute for specifying result segments must have 4 elements, but got 3}}
+ // expected-error @+1 {{ op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}}
%0:4 = "test.attr_sized_results"() {result_segment_sizes = array<i32: 2, 1, 1>} : () -> (i32, i32, i32, i32)
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 8056f6d7e03183..966896b27d1cb1 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -839,8 +839,7 @@ def AttrSizedOperandOp : TEST_Op<"attr_sized_operands",
Variadic<I32>:$a,
Variadic<I32>:$b,
I32:$c,
- Variadic<I32>:$d,
- DenseI32ArrayAttr:$operand_segment_sizes
+ Variadic<I32>:$d
);
}
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 5e8414ad4055c6..88f48d0d544e7c 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -100,7 +100,7 @@ def named_form(lhs, rhs):
init_result = tensor.EmptyOp([4, 8], f32)
# CHECK: "linalg.matmul"(%{{.*}})
# CHECK-SAME: cast = #linalg.type_fn<cast_signed>
- # CHECK-SAME: operand_segment_sizes = array<i32: 2, 1>
+ # CHECK-SAME: odsOperandSegmentSizes = array<i32: 2, 1>
# CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
# CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index c069e7ce89a110..d0e888653239ad 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -34,6 +34,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Signals.h"
+#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
@@ -115,6 +116,10 @@ static const char *const adapterSegmentSizeAttrInitCode = R"(
assert({0} && "missing segment size attribute for op");
auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";
+static const char *const adapterSegmentSizeAttrInitCodeProperties = R"(
+ ::llvm::ArrayRef<int32_t> sizeAttr = {0};
+)";
+
/// The code snippet to initialize the sizes for the value range calculation.
///
/// {0}: The code to get the attribute.
@@ -150,6 +155,29 @@ static const char *const valueRangeReturnCode = R"(
std::next({0}, valueRange.first + valueRange.second)};
)";
+/// Read operand/result segment_size from bytecode.
+static const char *const readBytecodeSegmentSize = R"(
+if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
+ DenseI32ArrayAttr attr;
+ if (::mlir::failed($_reader.readAttribute(attr))) return failure();
+ if (attr.size() > static_cast<int64_t>(sizeof($_storage) / sizeof(int32_t))) {
+ $_reader.emitError("size mismatch for operand/result_segment_size");
+ return failure();
+ }
+ llvm::copy(ArrayRef<int32_t>(attr), $_storage);
+} else {
+ return $_reader.readSparseArray(MutableArrayRef($_storage));
+}
+)";
+
+/// Write operand/result segment_size to bytecode.
+static const char *const writeBytecodeSegmentSize = R"(
+if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6)
+ $_writer.writeAttribute(DenseI32ArrayAttr::get(getContext(), $_storage));
+else
+ $_writer.writeSparseArray(ArrayRef($_storage));
+)";
+
/// A header for indicating code sections.
///
/// {0}: Some text, or a class name.
@@ -343,6 +371,9 @@ class OpOrAdaptorHelper {
return true;
if (!op.getDialect().usePropertiesForAttributes())
return false;
+ if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") ||
+ op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
+ return true;
return llvm::any_of(getAttrMetadata(),
[](const std::pair<StringRef, AttributeMetadata> &it) {
return !it.second.constraint ||
@@ -350,6 +381,14 @@ class OpOrAdaptorHelper {
});
}
+ std::optional<NamedProperty> &getOperandSegmentsSize() {
+ return operandSegmentsSize;
+ }
+
+ std::optional<NamedProperty> &getResultSegmentsSize() {
+ return resultSegmentsSize;
+ }
+
private:
// Compute the attribute metadata.
void computeAttrMetadata();
@@ -361,6 +400,13 @@ class OpOrAdaptorHelper {
// The attribute metadata, mapped by name.
llvm::MapVector<StringRef, AttributeMetadata> attrMetadata;
+
+ // Property
+ std::optional<NamedProperty> operandSegmentsSize;
+ std::string operandSegmentsSizeStorage;
+ std::optional<NamedProperty> resultSegmentsSize;
+ std::string resultSegmentsSizeStorage;
+
// The number of required attributes.
unsigned numRequired;
};
@@ -377,18 +423,50 @@ void OpOrAdaptorHelper::computeAttrMetadata() {
attrMetadata.insert(
{namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}});
}
+
+ auto makeProperty = [&](StringRef storageType) {
+ return Property(
+ /*storageType=*/storageType,
+ /*interfaceType=*/"::llvm::ArrayRef<int32_t>",
+ /*convertFromStorageCall=*/"$_storage",
+ /*assignToStorageCall=*/"::llvm::copy($_value, $_storage)",
+ /*convertToAttributeCall=*/
+ "DenseI32ArrayAttr::get($_ctxt, $_storage)",
+ /*convertFromAttributeCall=*/
+ "return convertFromAttribute($_storage, $_attr, $_diag);",
+ /*readFromMlirBytecodeCall=*/readBytecodeSegmentSize,
+ /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSize,
+ /*hashPropertyCall=*/
+ "llvm::hash_combine_range(std::begin($_storage), "
+ "std::end($_storage));",
+ /*StringRef defaultValue=*/"");
+ };
// Include key attributes from several traits as implicitly registered.
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
- attrMetadata.insert(
- {operandSegmentAttrName,
- AttributeMetadata{operandSegmentAttrName, /*isRequired=*/true,
- /*attr=*/std::nullopt}});
+ if (op.getDialect().usePropertiesForAttributes()) {
+ operandSegmentsSizeStorage =
+ llvm::formatv("int32_t[{0}]", op.getNumOperands());
+ operandSegmentsSize = {"odsOperandSegmentSizes",
+ makeProperty(operandSegmentsSizeStorage)};
+ } else {
+ attrMetadata.insert(
+ {operandSegmentAttrName, AttributeMetadata{operandSegmentAttrName,
+ /*isRequired=*/true,
+ /*attr=*/std::nullopt}});
+ }
}
if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
- attrMetadata.insert(
- {resultSegmentAttrName,
- AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true,
- /*attr=*/std::nullopt}});
+ if (op.getDialect().usePropertiesForAttributes()) {
+ resultSegmentsSizeStorage =
+ llvm::formatv("int32_t[{0}]", op.getNumResults());
+ resultSegmentsSize = {"odsResultSegmentSizes",
+ makeProperty(resultSegmentsSizeStorage)};
+ } else {
+ attrMetadata.insert(
+ {resultSegmentAttrName,
+ AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true,
+ /*attr=*/std::nullopt}});
+ }
}
// Store the metadata in sorted order.
@@ -660,14 +738,17 @@ static void genNativeTraitAttrVerifier(MethodBody &body,
// Verify a few traits first so that we can use getODSOperands() and
// getODSResults() in the rest of the verifier.
auto &op = emitHelper.getOp();
- if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
- body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName,
- op.getNumOperands(), "operand",
- emitHelper.emitErrorPrefix());
- }
- if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
- body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName,
- op.getNumResults(), "result", emitHelper.emitErrorPrefix());
+ if (!op.getDialect().usePropertiesForAttributes()) {
+ if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
+ body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName,
+ op.getNumOperands(), "operand",
+ emitHelper.emitErrorPrefix());
+ }
+ if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
+ body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName,
+ op.getNumResults(), "result",
+ emitHelper.emitErrorPrefix());
+ }
}
}
@@ -964,14 +1045,16 @@ static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
void OpEmitter::genAttrNameGetters() {
const llvm::MapVector<StringRef, AttributeMetadata> &attributes =
emitHelper.getAttrMetadata();
-
+ bool hasOperandSegmentsSize =
+ op.getDialect().usePropertiesForAttributes() &&
+ op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
// Emit the getAttributeNames method.
{
auto *method = opClass.addStaticInlineMethod(
"::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
ERROR_IF_PRUNED(method, "getAttributeNames", op);
auto &body = method->body();
- if (attributes.empty()) {
+ if (!hasOperandSegmentsSize && attributes.empty()) {
body << " return {};";
// Nothing else to do if there are no registered attributes. Exit early.
return;
@@ -981,6 +1064,11 @@ void OpEmitter::genAttrNameGetters() {
[&](StringRef attrName) {
body << "::llvm::StringRef(\"" << attrName << "\")";
});
+ if (hasOperandSegmentsSize) {
+ if (!attributes.empty())
+ body << ", ";
+ body << "::llvm::StringRef(\"" << operandSegmentAttrName << "\")";
+ }
body << "};\n return ::llvm::ArrayRef(attrNames);";
}
@@ -1033,6 +1121,26 @@ void OpEmitter::genAttrNameGetters() {
"name, " + Twine(index));
}
}
+ if (hasOperandSegmentsSize) {
+ std::string name = op.getGetterName(operandSegmentAttrName);
+ std::string methodName = name + "AttrName";
+ // Generate the non-static variant.
+ {
+ auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName);
+ ERROR_IF_PRUNED(method, methodName, op);
+ method->body()
+ << " return (*this)->getName().getAttributeNames().back();";
+ }
+
+ // Generate the static variant.
+ {
+ auto *method = opClass.addStaticInlineMethod(
+ "::mlir::StringAttr", methodName,
+ MethodParameter("::mlir::OperationName", "name"));
+ ERROR_IF_PRUNED(method, methodName, op);
+ method->body() << " return name.getAttributeNames().back();";
+ }
+ }
}
// Emit the getter for an attribute with the return type specified.
@@ -1080,6 +1188,10 @@ void OpEmitter::genPropertiesSupport() {
}
for (const NamedProperty &prop : op.getProperties())
attrOrProperties.push_back(&prop);
+ if (emitHelper.getOperandSegmentsSize())
+ attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value());
+ if (emitHelper.getResultSegmentsSize())
+ attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value());
if (attrOrProperties.empty())
return;
auto &setPropMethod =
@@ -1104,6 +1216,7 @@ void OpEmitter::genPropertiesSupport() {
auto &getInherentAttrMethod =
opClass
.addStaticMethod("std::optional<mlir::Attribute>", "getInherentAttr",
+ MethodParameter("::mlir::MLIRContext *", "ctx"),
MethodParameter("const Properties &", "prop"),
MethodParameter("llvm::StringRef", "name"))
->body();
@@ -1117,6 +1230,7 @@ void OpEmitter::genPropertiesSupport() {
auto &populateInherentAttrsMethod =
opClass
.addStaticMethod("void", "populateInherentAttrs",
+ MethodParameter("::mlir::MLIRContext *", "ctx"),
MethodParameter("const Properties &", "prop"),
MethodParameter("::mlir::NamedAttrList &", "attrs"))
->body();
@@ -1165,7 +1279,7 @@ void OpEmitter::genPropertiesSupport() {
::mlir::InFlightDiagnostic *propDiag) {{
{0};
};
- auto attr = dict.get("{1}");
+ {2};
if (!attr) {{
if (diag)
*diag << "expected key entry for {1} in DictionaryAttr to set "
@@ -1176,25 +1290,50 @@ void OpEmitter::genPropertiesSupport() {
return ::mlir::failure();
}
)decl";
+
for (const auto &attrOrProp : attrOrProperties) {
if (const auto *namedProperty =
llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
StringRef name = namedProperty->name;
auto &prop = namedProperty->prop;
FmtContext fctx;
+
+ std::string getAttr;
+ llvm::raw_string_ostream os(getAttr);
+ os << " auto attr = dict.get(\"" << name << "\");";
+ if (name == "odsOperandSegmentSizes") {
+ os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
+ }
+ if (name == "odsResultSegmentSizes") {
+ os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
+ }
+ os.flush();
+
setPropMethod << formatv(propFromAttrFmt,
tgfmt(prop.getConvertFromAttributeCall(),
&fctx.addSubst("_attr", propertyAttr)
.addSubst("_storage", propertyStorage)
.addSubst("_diag", propertyDiag)),
- name);
+ name, getAttr);
+
} else {
const auto *namedAttr = llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
StringRef name = namedAttr->attrName;
+ std::string getAttr;
+ llvm::raw_string_ostream os(getAttr);
+ os << " auto attr = dict.get(\"" << name << "\");";
+ if (name == "odsOperandSegmentSizes") {
+ os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
+ }
+ if (name == "odsResultSegmentSizes") {
+ os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
+ }
+ os.flush();
+
setPropMethod << formatv(R"decl(
{{
auto &propStorage = prop.{0};
- auto attr = dict.get("{0}");
+ {2}
if (attr || /*isRequired=*/{1}) {{
if (!attr) {{
if (diag)
@@ -1213,7 +1352,7 @@ void OpEmitter::genPropertiesSupport() {
}
}
)decl",
- name, namedAttr->isRequired);
+ name, namedAttr->isRequired, getAttr);
}
}
setPropMethod << " return ::mlir::success();\n";
@@ -1318,6 +1457,60 @@ void OpEmitter::genPropertiesSupport() {
<< formatv(populateInherentAttrsMethodFmt, name);
continue;
}
+ // The ODS segment size property is "special": we expose it as an attribute
+ // even though it is a native property.
+ const auto *namedProperty = cast<const NamedProperty *>(attrOrProp);
+ StringRef name = namedProperty->name;
+ if (name != "odsOperandSegmentSizes" && name != "odsResultSegmentSizes")
+ continue;
+ auto &prop = namedProperty->prop;
+ FmtContext fctx;
+ fctx.addSubst("_ctxt", "ctx");
+ fctx.addSubst("_storage", Twine("prop.") + name);
+ if (name == "odsOperandSegmentSizes") {
+ getInherentAttrMethod
+ << formatv(" if (name == \"odsOperandSegmentSizes\" || name == "
+ "\"{0}\") return ",
+ operandSegmentAttrName);
+ } else {
+ getInherentAttrMethod
+ << formatv(" if (name == \"odsResultSegmentSizes\" || name == "
+ "\"{0}\") return ",
+ resultSegmentAttrName);
+ }
+ getInherentAttrMethod << tgfmt(prop.getConvertToAttributeCall(), &fctx)
+ << ";\n";
+
+ if (name == "odsOperandSegmentSizes") {
+ setInherentAttrMethod << formatv(
+ " if (name == \"odsOperandSegmentSizes\" || name == "
+ "\"{0}\") {{",
+ operandSegmentAttrName);
+ } else {
+ setInherentAttrMethod
+ << formatv(" if (name == \"odsResultSegmentSizes\" || name == "
+ "\"{0}\") {{",
+ resultSegmentAttrName);
+ }
+ setInherentAttrMethod << formatv(R"decl(
+ auto arrAttr = dyn_cast_or_null<DenseI32ArrayAttr>(value);
+ if (!arrAttr) return;
+ if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t))
+ return;
+ llvm::copy(arrAttr.asArrayRef(), prop.{0});
+ return;
+ }
+)decl",
+ name);
+ if (name == "odsOperandSegmentSizes") {
+ populateInherentAttrsMethod
+ << formatv(" attrs.append(\"{0}\", {1});\n", operandSegmentAttrName,
+ tgfmt(prop.getConvertToAttributeCall(), &fctx));
+ } else {
+ populateInherentAttrsMethod
+ << formatv(" attrs.append(\"{0}\", {1});\n", resultSegmentAttrName,
+ tgfmt(prop.getConvertToAttributeCall(), &fctx));
+ }
}
getInherentAttrMethod << " return std::nullopt;\n";
@@ -1815,8 +2008,13 @@ void OpEmitter::genNamedOperandGetters() {
// array.
std::string attrSizeInitCode;
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
- attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
- emitHelper.getAttr(operandSegmentAttrName));
+ if (op.getDialect().usePropertiesForAttributes())
+ attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties,
+ "getProperties().odsOperandSegmentSizes");
+
+ else
+ attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
+ emitHelper.getAttr(operandSegmentAttrName));
}
generateNamedOperandGetters(
@@ -1851,10 +2049,11 @@ void OpEmitter::genNamedOperandSetters() {
"range.first, range.second";
if (attrSizedOperands) {
if (emitHelper.hasProperties())
- body << formatv(
- ", ::mlir::MutableOperandRange::OperandSegment({0}u, "
- "{getOperandSegmentSizesAttrName(), getProperties().{1}})",
- i, operandSegmentAttrName);
+ body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
+ "{{getOperandSegmentSizesAttrName(), "
+ "DenseI32ArrayAttr::get(getContext(), "
+ "getProperties().odsOperandSegmentSizes)})",
+ i);
else
body << formatv(
", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
@@ -1910,8 +2109,13 @@ void OpEmitter::genNamedResultGetters() {
// Build the initializer string for the result segment size attribute.
std::string attrSizeInitCode;
if (attrSizedResults) {
- attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
- emitHelper.getAttr(resultSegmentAttrName));
+ if (op.getDialect().usePropertiesForAttributes())
+ attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties,
+ "getProperties().odsResultSegmentSizes");
+
+ else
+ attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
+ emitHelper.getAttr(resultSegmentAttrName));
}
generateValueRangeStartAndEnd(
@@ -2086,10 +2290,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
// the length of the type ranges.
if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
if (op.getDialect().usePropertiesForAttributes()) {
- body << " (" << builderOpState
- << ".getOrAddProperties<Properties>()." << resultSegmentAttrName
- << " = \n"
- " odsBuilder.getDenseI32ArrayAttr({";
+ body << " llvm::copy(ArrayRef<int32_t>({";
} else {
std::string getterName = op.getGetterName(resultSegmentAttrName);
body << " " << builderOpState << ".addAttribute(" << getterName
@@ -2112,7 +2313,12 @@ void OpEmitter::genSeparateArgParamBuilder() {
body << "static_cast<int32_t>(" << resultNames[i] << ".size())";
}
});
- body << "}));\n";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << "}), " << builderOpState
+ << ".getOrAddProperties<Properties>().odsResultSegmentSizes);\n";
+ } else {
+ body << "}));\n";
+ }
}
return;
@@ -2706,17 +2912,7 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
}
// If the operation has the operand segment size attribute, add it here.
- if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
- std::string sizes = op.getGetterName(operandSegmentAttrName);
- if (op.getDialect().usePropertiesForAttributes()) {
- body << " (" << builderOpState << ".getOrAddProperties<Properties>()."
- << operandSegmentAttrName << "= "
- << "odsBuilder.getDenseI32ArrayAttr({";
- } else {
- body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
- << builderOpState << ".name), "
- << "odsBuilder.getDenseI32ArrayAttr({";
- }
+ auto emitSegment = [&]() {
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
const NamedTypeConstraint &operand = op.getOperand(i);
if (!operand.isVariableLength()) {
@@ -2737,7 +2933,21 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
}
});
- body << "}));\n";
+ };
+ if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
+ std::string sizes = op.getGetterName(operandSegmentAttrName);
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << " llvm::copy(ArrayRef<int32_t>({";
+ emitSegment();
+ body << "}), " << builderOpState
+ << ".getOrAddProperties<Properties>().odsOperandSegmentSizes);\n";
+ } else {
+ body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
+ << builderOpState << ".name), "
+ << "odsBuilder.getDenseI32ArrayAttr({";
+ emitSegment();
+ body << "}));\n";
+ }
}
// Push all attributes to the result.
@@ -3541,6 +3751,10 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
}
for (const NamedProperty &prop : op.getProperties())
attrOrProperties.push_back(&prop);
+ if (emitHelper.getOperandSegmentsSize())
+ attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value());
+ if (emitHelper.getResultSegmentsSize())
+ attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value());
assert(!attrOrProperties.empty());
std::string declarations = " struct Properties {\n";
llvm::raw_string_ostream os(declarations);
@@ -3598,7 +3812,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
if (attr) {
storageType = attr->getStorageType();
} else {
- if (name != operandSegmentAttrName && name != resultSegmentAttrName) {
+ if (name != "odsOperandSegmentSizes" &&
+ name != "odsResultSegmentSizes") {
report_fatal_error("unexpected AttributeMetadata");
}
// TODO: update to use native integers.
@@ -3710,8 +3925,13 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
std::string sizeAttrInit;
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
- sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode,
- emitHelper.getAttr(operandSegmentAttrName));
+ if (op.getDialect().usePropertiesForAttributes())
+ sizeAttrInit =
+ formatv(adapterSegmentSizeAttrInitCodeProperties,
+ llvm::formatv("getProperties().odsOperandSegmentSizes"));
+ else
+ sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode,
+ emitHelper.getAttr(operandSegmentAttrName));
}
generateNamedOperandGetters(op, genericAdaptor,
/*genericAdaptorBase=*/&genericAdaptorBase,
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 9a5e4d52c550a4..1e131799fbdced 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1654,16 +1654,6 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
MethodBody &body) {
if (!allOperands) {
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
- if (op.getDialect().usePropertiesForAttributes()) {
- body << formatv(" "
- "result.getOrAddProperties<{0}::Properties>().operand_"
- "segment_sizes = "
- "(parser.getBuilder().getDenseI32ArrayAttr({{",
- op.getCppClassName());
- } else {
- body << " result.addAttribute(\"operand_segment_sizes\", "
- << "parser.getBuilder().getDenseI32ArrayAttr({";
- }
auto interleaveFn = [&](const NamedTypeConstraint &operand) {
// If the operand is variadic emit the parsed size.
if (operand.isVariableLength())
@@ -1671,8 +1661,19 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
else
body << "1";
};
- llvm::interleaveComma(op.getOperands(), body, interleaveFn);
- body << "}));\n";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << "llvm::copy(ArrayRef<int32_t>({";
+ llvm::interleaveComma(op.getOperands(), body, interleaveFn);
+ body << formatv("}), "
+ "result.getOrAddProperties<{0}::Properties>()."
+ "odsOperandSegmentSizes);\n",
+ op.getCppClassName());
+ } else {
+ body << " result.addAttribute(\"operand_segment_sizes\", "
+ << "parser.getBuilder().getDenseI32ArrayAttr({";
+ llvm::interleaveComma(op.getOperands(), body, interleaveFn);
+ body << "}));\n";
+ }
}
for (const NamedTypeConstraint &operand : op.getOperands()) {
if (!operand.isVariadicOfVariadic())
@@ -1697,16 +1698,6 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
if (!allResultTypes &&
op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
- if (op.getDialect().usePropertiesForAttributes()) {
- body << formatv(
- " "
- "result.getOrAddProperties<{0}::Properties>().result_segment_sizes = "
- "(parser.getBuilder().getDenseI32ArrayAttr({{",
- op.getCppClassName());
- } else {
- body << " result.addAttribute(\"result_segment_sizes\", "
- << "parser.getBuilder().getDenseI32ArrayAttr({";
- }
auto interleaveFn = [&](const NamedTypeConstraint &result) {
// If the result is variadic emit the parsed size.
if (result.isVariableLength())
@@ -1714,8 +1705,20 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
else
body << "1";
};
- llvm::interleaveComma(op.getResults(), body, interleaveFn);
- body << "}));\n";
+ if (op.getDialect().usePropertiesForAttributes()) {
+ body << "llvm::copy(ArrayRef<int32_t>({";
+ llvm::interleaveComma(op.getResults(), body, interleaveFn);
+ body << formatv(
+ "}), "
+ "result.getOrAddProperties<{0}::Properties>().odsResultSegmentSizes"
+ ");\n",
+ op.getCppClassName());
+ } else {
+ body << " result.addAttribute(\"result_segment_sizes\", "
+ << "parser.getBuilder().getDenseI32ArrayAttr({";
+ llvm::interleaveComma(op.getResults(), body, interleaveFn);
+ body << "}));\n";
+ }
}
}
diff --git a/mlir/unittests/IR/AdaptorTest.cpp b/mlir/unittests/IR/AdaptorTest.cpp
index fcd19addd8e3ad..e4dc4cda7608ac 100644
--- a/mlir/unittests/IR/AdaptorTest.cpp
+++ b/mlir/unittests/IR/AdaptorTest.cpp
@@ -39,7 +39,7 @@ TEST(Adaptor, GenericAdaptorsOperandAccess) {
// value from the value 0.
SmallVector<std::optional<int>> v = {0, 4};
OIListSimple::Properties prop;
- prop.operand_segment_sizes = builder.getDenseI32ArrayAttr({1, 0, 1});
+ llvm::copy(ArrayRef{1, 0, 1}, prop.odsOperandSegmentSizes);
OIListSimple::GenericAdaptor<ArrayRef<std::optional<int>>> d(v, {}, prop,
{});
EXPECT_EQ(d.getArg0(), 0);
diff --git a/mlir/unittests/IR/OpPropertiesTest.cpp b/mlir/unittests/IR/OpPropertiesTest.cpp
index eda84a0e0b075a..21ea4488e7190f 100644
--- a/mlir/unittests/IR/OpPropertiesTest.cpp
+++ b/mlir/unittests/IR/OpPropertiesTest.cpp
@@ -115,13 +115,15 @@ class OpWithProperties : public Op<OpWithProperties> {
// This alias is the only definition needed for enabling "properties" for this
// operation.
using Properties = TestProperties;
- static std::optional<mlir::Attribute> getInherentAttr(const Properties &prop,
+ static std::optional<mlir::Attribute> getInherentAttr(MLIRContext *context,
+ const Properties &prop,
StringRef name) {
return std::nullopt;
}
static void setInherentAttr(Properties &prop, StringRef name,
mlir::Attribute value) {}
- static void populateInherentAttrs(const Properties &prop,
+ static void populateInherentAttrs(MLIRContext *context,
+ const Properties &prop,
NamedAttrList &attrs) {}
static LogicalResult
verifyInherentAttrs(OperationName opName, NamedAttrList &attrs,
More information about the Mlir-commits
mailing list