[Mlir-commits] [mlir] 0af2527 - Update ElementsAttr::isValidIndex to handle ElementsAttr with a scalar. Scalar will have rank 0.

Jacques Pienaar llvmlistbot at llvm.org
Fri Jan 29 16:56:22 PST 2021


Author: karimnosseir
Date: 2021-01-29T16:56:00-08:00
New Revision: 0af25275364e27d9766eb0912a5dd9731d62936b

URL: https://github.com/llvm/llvm-project/commit/0af25275364e27d9766eb0912a5dd9731d62936b
DIFF: https://github.com/llvm/llvm-project/commit/0af25275364e27d9766eb0912a5dd9731d62936b.diff

LOG: Update ElementsAttr::isValidIndex to handle an ElementsAttr holding a scalar. A scalar has rank 0, so the single index {0} is now also accepted as valid.


Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D95663

Added: 
    

Modified: 
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/unittests/IR/AttributeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index f84d0af5c9a1..162bed96e3f4 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -459,6 +459,8 @@ bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
 
   // Verify that the rank of the indices matches the held type.
   auto rank = type.getRank();
+  if (rank == 0 && index.size() == 1 && index[0] == 0)
+    return true;
   if (rank != static_cast<int64_t>(index.size()))
     return false;
 

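The hunk above shows only the start of isValidIndex. Below is a minimal standalone sketch of the resulting check. It assumes that the rest of the function, which this diff does not show, bounds-checks each index against its dimension. `isValidIndexSketch` is a hypothetical name, and a plain `std::vector<int64_t>` stands in for the held type's shape. A rank-0 type has an empty shape.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ElementsAttr::isValidIndex. `shape` plays the role
// of the held ShapedType's dimensions (empty for a rank-0 / scalar type).
bool isValidIndexSketch(const std::vector<int64_t> &shape,
                        const std::vector<uint64_t> &index) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  // New special case from this commit: a scalar also accepts the index {0}.
  if (rank == 0 && index.size() == 1 && index[0] == 0)
    return true;
  // Otherwise the number of indices must match the rank exactly.
  if (rank != static_cast<int64_t>(index.size()))
    return false;
  // Assumed bounds check (not shown in the hunk): each index is in range.
  for (int64_t i = 0; i < rank; ++i)
    if (static_cast<int64_t>(index[i]) >= shape[i])
      return false;
  return true;
}
```

With this logic, a rank-0 shape accepts both the empty index {} and {0}, while {1} is still rejected because it fails the rank comparison.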
diff  --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index e1f603aec446..73acb7054754 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -192,4 +192,15 @@ TEST(DenseComplexTest, ComplexAPIntSplat) {
   testSplat(complexType, value);
 }
 
+TEST(DenseScalarTest, ExtractZeroRankElement) {
+  MLIRContext context;
+  const int elementValue = 12;
+  IntegerType intTy = IntegerType::get(&context, 32);
+  Attribute value = IntegerAttr::get(intTy, elementValue);
+  RankedTensorType shape = RankedTensorType::get({}, intTy);
+
+  auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
+  EXPECT_TRUE(attr.getValue({0}) == value);
+}
+
 } // end namespace

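The new unit test relies on getValue({0}) reaching the one stored element of the rank-0 tensor. The sketch below shows why that works, under the assumption that the accessor reduces a multi-dimensional index to a row-major offset into the flat element buffer. `flattenIndexSketch` is a hypothetical helper, not MLIR API. For a rank-0 shape, no dimension contributes to the offset, so an accepted index such as {0} lands on element 0.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical row-major flattening: maps a multi-dimensional index to a
// position in a flat element buffer with the given shape.
uint64_t flattenIndexSketch(const std::vector<int64_t> &shape,
                            const std::vector<uint64_t> &index) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  uint64_t flat = 0;
  uint64_t stride = 1;
  // Walk dimensions from innermost to outermost. For rank 0 the loop body
  // never runs, so any accepted index (including {0}) maps to offset 0.
  for (int64_t i = rank - 1; i >= 0; --i) {
    assert(i < static_cast<int64_t>(index.size()) && "index shorter than rank");
    flat += index[i] * stride;
    stride *= static_cast<uint64_t>(shape[i]);
  }
  return flat;
}
```

For example, in a 2x3 shape the index {1, 2} maps to offset 1 * 3 + 2 = 5.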
