[Mlir-commits] [mlir] 64ce74a - [mlir] Handle StringAttr in SparseElementsAttr::getZeroAttr.

Chia-hung Duan llvmlistbot at llvm.org
Thu Nov 4 11:20:28 PDT 2021


Author: Chia-hung Duan
Date: 2021-11-04T18:18:20Z
New Revision: 64ce74a6c8f23481f8062830e3ca7f38e171d74c

URL: https://github.com/llvm/llvm-project/commit/64ce74a6c8f23481f8062830e3ca7f38e171d74c
DIFF: https://github.com/llvm/llvm-project/commit/64ce74a6c8f23481f8062830e3ca7f38e171d74c.diff

LOG: [mlir] Handle StringAttr in SparseElementsAttr::getZeroAttr.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D111203

Added: 
    

Modified: 
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/unittests/IR/AttributeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 72891d995af6..2acc386d259e 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1336,8 +1336,11 @@ Attribute SparseElementsAttr::getZeroAttr() const {
   if (eltType.isa<FloatType>())
     return FloatAttr::get(eltType, 0);
 
+  // Handle string type.
+  if (getValues().isa<DenseStringElementsAttr>())
+    return StringAttr::get("", eltType);
+
   // Otherwise, this is an integer.
-  // TODO: Handle StringAttr here.
   return IntegerAttr::get(eltType, 0);
 }
 

diff  --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 891abd1a4f23..aaff61e7d5f9 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -205,4 +205,50 @@ TEST(DenseScalarTest, ExtractZeroRankElement) {
   EXPECT_TRUE(attr.getValue({0}) == value);
 }
 
+TEST(SparseElementsAttrTest, GetZero) {
+  MLIRContext context;
+  context.allowUnregisteredDialects();
+
+  IntegerType intTy = IntegerType::get(&context, 32);
+  FloatType floatTy = FloatType::getF32(&context);
+  Type stringTy = OpaqueType::get(Identifier::get("test", &context), "string");
+
+  ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
+  ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
+  ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
+
+  auto indicesType =
+      RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
+  auto indices =
+      DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
+
+  RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
+  auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
+
+  RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
+  auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
+
+  RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
+  auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
+
+  auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
+  auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
+  auto sparseString =
+      SparseElementsAttr::get(tensorString, indices, stringValue);
+
+  // Only index (0, 0) contains an element, others are supposed to return
+  // the zero/empty value.
+  auto zeroIntValue = sparseInt.getValue({1, 1});
+  EXPECT_EQ(zeroIntValue.cast<IntegerAttr>().getInt(), 0);
+  EXPECT_TRUE(zeroIntValue.getType() == intTy);
+
+  auto zeroFloatValue = sparseFloat.getValue({1, 1});
+  EXPECT_EQ(zeroFloatValue.cast<FloatAttr>().getValueAsDouble(), 0.0f);
+  EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
+
+  auto zeroStringValue = sparseString.getValue({1, 1});
+  EXPECT_TRUE(zeroStringValue.cast<StringAttr>().getValue().empty());
+  EXPECT_TRUE(zeroStringValue.getType() == stringTy);
+}
+
 } // end namespace


        


More information about the Mlir-commits mailing list