[Mlir-commits] [mlir] 0af2527 - Update ElementsAttr::isValidIndex to handle ElementsAttr with a scalar. Scalar will have rank 0.
Jacques Pienaar
llvmlistbot at llvm.org
Fri Jan 29 16:56:22 PST 2021
Author: karimnosseir
Date: 2021-01-29T16:56:00-08:00
New Revision: 0af25275364e27d9766eb0912a5dd9731d62936b
URL: https://github.com/llvm/llvm-project/commit/0af25275364e27d9766eb0912a5dd9731d62936b
DIFF: https://github.com/llvm/llvm-project/commit/0af25275364e27d9766eb0912a5dd9731d62936b.diff
LOG: Update ElementsAttr::isValidIndex to handle ElementsAttr with a scalar. Scalar will have rank 0.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D95663
Added:
Modified:
mlir/lib/IR/BuiltinAttributes.cpp
mlir/unittests/IR/AttributeTest.cpp
Removed:
################################################################################
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index f84d0af5c9a1..162bed96e3f4 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -459,6 +459,8 @@ bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
 
   // Verify that the rank of the indices matches the held type.
   auto rank = type.getRank();
+  if (rank == 0 && index.size() == 1 && index[0] == 0)
+    return true;
   if (rank != static_cast<int64_t>(index.size()))
     return false;
 
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index e1f603aec446..73acb7054754 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -192,4 +192,15 @@ TEST(DenseComplexTest, ComplexAPIntSplat) {
   testSplat(complexType, value);
 }
 
+TEST(DenseScalarTest, ExtractZeroRankElement) {
+  MLIRContext context;
+  const int elementValue = 12;
+  IntegerType intTy = IntegerType::get(&context, 32);
+  Attribute value = IntegerAttr::get(intTy, elementValue);
+  RankedTensorType shape = RankedTensorType::get({}, intTy);
+
+  auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
+  EXPECT_TRUE(attr.getValue({0}) == value);
+}
+
 } // end namespace
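The patch only changes the early rank check, so its effect can be modelled outside MLIR. The block below is a minimal standalone sketch, not MLIR code: `isValidIndexSketch` and `flattenedIndexSketch` are hypothetical names, a `std::vector<int64_t>` stands in for the attribute's static shape (an empty vector models a rank-0 type such as `tensor<i32>`), and the per-dimension bounds check after the rank test is an assumption because the hunk above does not show it. It illustrates that a rank-0 shape now accepts both `{}` and `{0}`, and that either index addresses the single stored element at flat offset 0.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical standalone model of the patched isValidIndex check.
// `shape` stands in for the attribute's static shape; an empty vector
// models a rank-0 (scalar) tensor.
bool isValidIndexSketch(const std::vector<int64_t> &shape,
                        const std::vector<uint64_t> &index) {
  const auto rank = static_cast<int64_t>(shape.size());

  // The case added by the patch: a rank-0 tensor also accepts {0}.
  if (rank == 0 && index.size() == 1 && index[0] == 0)
    return true;

  // Otherwise the number of indices must match the rank exactly.
  if (rank != static_cast<int64_t>(index.size()))
    return false;

  // Assumed bounds check: every index must lie inside its dimension.
  for (int64_t i = 0; i < rank; ++i) {
    assert(shape[i] >= 0 && "sketch assumes a fully static shape");
    if (index[i] >= static_cast<uint64_t>(shape[i]))
      return false;
  }
  return true;
}

// Hypothetical row-major flattening, included to show why {0} is a
// natural index for a scalar: a rank-0 tensor holds exactly one element,
// stored at flat offset 0.
uint64_t flattenedIndexSketch(const std::vector<int64_t> &shape,
                              const std::vector<uint64_t> &index) {
  assert(isValidIndexSketch(shape, index) && "invalid index");
  uint64_t flat = 0;
  uint64_t stride = 1;
  // Walk dimensions from innermost to outermost. For rank 0 the loop body
  // never runs, so both {} and {0} map to offset 0.
  for (std::size_t i = shape.size(); i-- > 0;) {
    flat += index[i] * stride;
    stride *= static_cast<uint64_t>(shape[i]);
  }
  return flat;
}
```

With this model, a scalar shape accepts `{0}` but still rejects `{1}` and `{0, 0}`, which mirrors the new unit test reading the only element of a `tensor<i32>` attribute through `getValue({0})`.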