[Mlir-commits] [mlir] a5c46bf - [mlir] Allow DenseElementsAttr to use any shaped type
Jeff Niu
llvmlistbot at llvm.org
Fri Sep 30 23:25:23 PDT 2022
Author: Jeff Niu
Date: 2022-09-30T23:25:14-07:00
New Revision: a5c46bf9521e34d7f8c6fa048014912afb910020
URL: https://github.com/llvm/llvm-project/commit/a5c46bf9521e34d7f8c6fa048014912afb910020
DIFF: https://github.com/llvm/llvm-project/commit/a5c46bf9521e34d7f8c6fa048014912afb910020.diff
LOG: [mlir] Allow DenseElementsAttr to use any shaped type
This patch allows the type of DenseElementsAttr to be any shaped type.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D135002
Added:
Modified:
mlir/lib/AsmParser/AttributeParser.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/test/IR/invalid-builtin-attributes.mlir
Removed:
################################################################################
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 819c86c997f3c..1e79b44322616 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -1066,12 +1066,12 @@ ShapedType Parser::parseElementsLiteralType(Type type) {
return nullptr;
}
- if (!type.isa<RankedTensorType, VectorType>()) {
- emitError("elements literal must be a ranked tensor or vector type");
+ auto sType = type.dyn_cast<ShapedType>();
+ if (!sType) {
+ emitError("elements literal must be a shaped type");
return nullptr;
}
- auto sType = type.cast<ShapedType>();
if (!sType.hasStaticShape())
return (emitError("elements literal type must have static shape"), nullptr);
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 70c2b47f41721..ed22134d1dcc8 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1381,8 +1381,6 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
ArrayRef<char> data) {
- assert((type.isa<RankedTensorType, VectorType>()) &&
- "type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
bool isSplat = false;
bool isValid = isValidRawBuffer(type, data, isSplat);
@@ -1498,16 +1496,7 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
size_t bitWidth = getDenseElementBitWidth(newElementType);
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
- ShapedType newArrayType;
- if (inType.isa<RankedTensorType>())
- newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
- else if (inType.isa<UnrankedTensorType>())
- newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
- else if (auto vType = inType.dyn_cast<VectorType>())
- newArrayType = VectorType::get(vType.getShape(), newElementType,
- vType.getNumScalableDims());
- else
- assert(newArrayType && "Unhandled tensor type");
+ ShapedType newArrayType = inType.cloneWith(inType.getShape(), newElementType);
size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT));
diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 49acce2cf1187..8e57afa41ba88 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
func.func @elementsattr_non_tensor_type() -> () {
- "foo"(){bar = dense<[4]> : i32} : () -> () // expected-error {{elements literal must be a ranked tensor or vector type}}
+ "foo"(){bar = dense<[4]> : i32} : () -> () // expected-error {{elements literal must be a shaped type}}
}
// -----
More information about the Mlir-commits mailing list