[Mlir-commits] [mlir] 8db947d - [mlir] Make ElementsAttr inherit from TypedAttr
Rahul Kayaith
llvmlistbot at llvm.org
Thu Apr 20 13:32:03 PDT 2023
Author: Rahul Kayaith
Date: 2023-04-20T16:31:53-04:00
New Revision: 8db947da0b377a13cd0e536b1bbc825abff3c001
URL: https://github.com/llvm/llvm-project/commit/8db947da0b377a13cd0e536b1bbc825abff3c001
DIFF: https://github.com/llvm/llvm-project/commit/8db947da0b377a13cd0e536b1bbc825abff3c001.diff
LOG: [mlir] Make ElementsAttr inherit from TypedAttr
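This allows implicit conversion from `ElementsAttr` to `TypedAttr`, but it
required renaming the `ElementsAttr::getType()` interface method to
`getShapedType`. Without the rename it would clash with `TypedAttr::getType()`,
which returns `::mlir::Type` rather than `::mlir::ShapedType`.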
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D148492
Added:
Modified:
mlir/include/mlir/Dialect/CommonFolders.h
mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/lib/IR/BuiltinAttributeInterfaces.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Transforms/ViewOpGraph.cpp
mlir/test/lib/Dialect/Test/TestAttrDefs.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 4f24580d02a70..5633b4ccd7cf6 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -189,7 +189,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
return {};
elementResults.push_back(*elementResult);
}
- return DenseElementsAttr::get(op.getType(), elementResults);
+ return DenseElementsAttr::get(op.getShapedType(), elementResults);
}
return {};
}
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index 19336e0483860..39c871d7cab15 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -15,11 +15,30 @@
include "mlir/IR/OpBase.td"
+//===----------------------------------------------------------------------===//
+// TypedAttrInterface
+//===----------------------------------------------------------------------===//
+
+def TypedAttrInterface : AttrInterface<"TypedAttr"> {
+ let cppNamespace = "::mlir";
+
+ let description = [{
+ This interface is used for attributes that have a type. The type of an
+ attribute is understood to represent the type of the data contained in the
+ attribute and is often used as the type of a value with this data.
+ }];
+
+ let methods = [InterfaceMethod<
+ "Get the attribute's type",
+ "::mlir::Type", "getType"
+ >];
+}
+
//===----------------------------------------------------------------------===//
// ElementsAttrInterface
//===----------------------------------------------------------------------===//
-def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
+def ElementsAttrInterface : AttrInterface<"ElementsAttr", [TypedAttrInterface]> {
let cppNamespace = "::mlir";
let description = [{
This interface is used for attributes that contain the constant elements of
@@ -78,7 +97,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
}
/// * Attribute
auto value_begin_impl(OverloadToken<mlir::Attribute>) const {
- mlir::Type elementType = getType().getElementType();
+ mlir::Type elementType = getShapedType().getElementType();
auto it = llvm::map_range(getElements(), [=](uint64_t value) {
return mlir::IntegerAttr::get(elementType,
llvm::APInt(/*numBits=*/64, value));
@@ -154,13 +173,15 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
InterfaceMethod<[{
Returns true if the attribute elements correspond to a splat, i.e. that
all elements of the attribute are the same value.
- }], "bool", "isSplat", (ins), /*defaultImplementation=*/[{}], [{
+ }], "bool", "isSplat", (ins), [{}], /*defaultImplementation=*/[{
// By default, only check for a single element splat.
return $_attr.getNumElements() == 1;
}]>,
InterfaceMethod<[{
Returns the shaped type of the elements attribute.
- }], "::mlir::ShapedType", "getType">
+ }], "::mlir::ShapedType", "getShapedType", (ins), [{}], /*defaultImplementation=*/[{
+ return $_attr.getType();
+ }]>
];
string ElementsAttrInterfaceAccessors = [{
@@ -325,7 +346,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
ArrayRef<uint64_t> index);
static uint64_t getFlattenedIndex(ElementsAttr elementsAttr,
ArrayRef<uint64_t> index) {
- return getFlattenedIndex(elementsAttr.getType(), index);
+ return getFlattenedIndex(elementsAttr.getShapedType(), index);
}
/// Returns the number of elements held by this attribute.
@@ -357,7 +378,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
/// Return the elements of this attribute as a value of type 'T'.
template <typename T>
DefaultValueCheckT<T, iterator_range<T>> getValues() const {
- return {getType(), value_begin<T>(), value_end<T>()};
+ return {getShapedType(), value_begin<T>(), value_end<T>()};
}
template <typename T>
DefaultValueCheckT<T, iterator<T>> value_begin() const;
@@ -377,7 +398,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
template <typename T, typename = DerivedAttrValueCheckT<T>>
DerivedAttrValueIteratorRange<T> getValues() const {
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
- return {getType(), llvm::map_range(getValues<Attribute>(),
+ return {getShapedType(), llvm::map_range(getValues<Attribute>(),
static_cast<T (*)(Attribute)>(castFn))};
}
template <typename T, typename = DerivedAttrValueCheckT<T>>
@@ -397,7 +418,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
template <typename T>
DefaultValueCheckT<T, std::optional<iterator_range<T>>> tryGetValues() const {
if (std::optional<iterator<T>> beginIt = try_value_begin<T>())
- return iterator_range<T>(getType(), *beginIt, value_end<T>());
+ return iterator_range<T>(getShapedType(), *beginIt, value_end<T>());
return std::nullopt;
}
template <typename T>
@@ -413,7 +434,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
return DerivedAttrValueIteratorRange<T>(
- getType(),
+ getShapedType(),
llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn))
);
}
@@ -474,23 +495,4 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
];
}
-//===----------------------------------------------------------------------===//
-// TypedAttrInterface
-//===----------------------------------------------------------------------===//
-
-def TypedAttrInterface : AttrInterface<"TypedAttr"> {
- let cppNamespace = "::mlir";
-
- let description = [{
- This interface is used for attributes that have a type. The type of an
- attribute is understood to represent the type of the data contained in the
- attribute and is often used as the type of a value with this data.
- }];
-
- let methods = [InterfaceMethod<
- "Get the attribute's type",
- "::mlir::Type", "getType"
- >];
-}
-
#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 446bc375d3fa6..01ab1a765aae0 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -218,7 +218,7 @@ def Builtin_DenseArray : Builtin_Attr<"DenseArray"> {
//===----------------------------------------------------------------------===//
def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
- "DenseIntOrFPElements", [ElementsAttrInterface, TypedAttrInterface],
+ "DenseIntOrFPElements", [ElementsAttrInterface],
"DenseElementsAttr"
> {
let summary = "An Attribute containing a dense multi-dimensional array of "
@@ -359,7 +359,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
//===----------------------------------------------------------------------===//
def Builtin_DenseStringElementsAttr : Builtin_Attr<
- "DenseStringElements", [ElementsAttrInterface, TypedAttrInterface],
+ "DenseStringElements", [ElementsAttrInterface],
"DenseElementsAttr"
> {
let summary = "An Attribute containing a dense multi-dimensional array of "
@@ -430,7 +430,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
//===----------------------------------------------------------------------===//
def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
- ElementsAttrInterface, TypedAttrInterface
+ ElementsAttrInterface
]> {
let summary = "An Attribute containing a dense multi-dimensional array "
"backed by a resource";
@@ -804,7 +804,7 @@ def Builtin_OpaqueAttr : Builtin_Attr<"Opaque", [TypedAttrInterface]> {
//===----------------------------------------------------------------------===//
def Builtin_SparseElementsAttr : Builtin_Attr<
- "SparseElements", [ElementsAttrInterface, TypedAttrInterface]
+ "SparseElements", [ElementsAttrInterface]
> {
let summary = "An opaque representation of a multi-dimensional array";
let description = [{
diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
index 6e35f120e22c4..ab216baede61c 100644
--- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
@@ -25,11 +25,11 @@ using namespace mlir::detail;
//===----------------------------------------------------------------------===//
Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
- return elementsAttr.getType().getElementType();
+ return elementsAttr.getShapedType().getElementType();
}
int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
- return elementsAttr.getType().getNumElements();
+ return elementsAttr.getShapedType().getNumElements();
}
bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
@@ -49,7 +49,7 @@ bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
}
bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
ArrayRef<uint64_t> index) {
- return isValidIndex(elementsAttr.getType(), index);
+ return isValidIndex(elementsAttr.getShapedType(), index);
}
uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index a08bb372fc738..66891afaf5d3b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -407,8 +407,8 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
// Fall back to element-by-element construction otherwise.
if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
- assert(elementsAttr.getType().hasStaticShape());
- assert(!elementsAttr.getType().getShape().empty() &&
+ assert(elementsAttr.getShapedType().hasStaticShape());
+ assert(!elementsAttr.getShapedType().getShape().empty() &&
"unexpected empty elements attribute shape");
SmallVector<llvm::Constant *, 8> constants;
@@ -422,7 +422,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
}
ArrayRef<llvm::Constant *> constantsRef = constants;
llvm::Constant *result = buildSequentialConstant(
- constantsRef, elementsAttr.getType().getShape(), llvmType, loc);
+ constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc);
assert(constantsRef.empty() && "did not consume all elemental constants");
return result;
}
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index bdefbe3d2abd5..4598b56a901d5 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -153,8 +153,8 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
// Elide "big" elements attributes.
auto elements = attr.dyn_cast<ElementsAttr>();
if (elements && elements.getNumElements() > largeAttrLimit) {
- os << std::string(elements.getType().getRank(), '[') << "..."
- << std::string(elements.getType().getRank(), ']') << " : "
+ os << std::string(elements.getShapedType().getRank(), '[') << "..."
+ << std::string(elements.getShapedType().getRank(), ']') << " : "
<< elements.getType();
return;
}
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index ba9909d08ae30..ec0a5548a1603 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -81,9 +81,7 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
}
// Test support for ElementsAttrInterface.
-def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
- ElementsAttrInterface, TypedAttrInterface
- ]> {
+def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [ElementsAttrInterface]> {
let mnemonic = "i64_elements";
let parameters = (ins
AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,
@@ -269,9 +267,7 @@ def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
}
// Test simple extern 1D vector using ElementsAttrInterface.
-def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
- ElementsAttrInterface, TypedAttrInterface
- ]> {
+def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [ElementsAttrInterface]> {
let mnemonic = "e1di64_elements";
let parameters = (ins
AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,
More information about the Mlir-commits
mailing list