[Mlir-commits] [mlir] 64ce74a - [mlir] Handle StringAttr in SparseElementsAttr::getZeroAttr.
Chia-hung Duan
llvmlistbot at llvm.org
Thu Nov 4 11:20:28 PDT 2021
Author: Chia-hung Duan
Date: 2021-11-04T18:18:20Z
New Revision: 64ce74a6c8f23481f8062830e3ca7f38e171d74c
URL: https://github.com/llvm/llvm-project/commit/64ce74a6c8f23481f8062830e3ca7f38e171d74c
DIFF: https://github.com/llvm/llvm-project/commit/64ce74a6c8f23481f8062830e3ca7f38e171d74c.diff
LOG: [mlir] Handle StringAttr in SparseElementsAttr::getZeroAttr.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D111203
Added:
Modified:
mlir/lib/IR/BuiltinAttributes.cpp
mlir/unittests/IR/AttributeTest.cpp
Removed:
################################################################################
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 72891d995af6..2acc386d259e 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1336,8 +1336,11 @@ Attribute SparseElementsAttr::getZeroAttr() const {
if (eltType.isa<FloatType>())
return FloatAttr::get(eltType, 0);
+ // Handle string type.
+ if (getValues().isa<DenseStringElementsAttr>())
+ return StringAttr::get("", eltType);
+
// Otherwise, this is an integer.
- // TODO: Handle StringAttr here.
return IntegerAttr::get(eltType, 0);
}
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 891abd1a4f23..aaff61e7d5f9 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -205,4 +205,50 @@ TEST(DenseScalarTest, ExtractZeroRankElement) {
EXPECT_TRUE(attr.getValue({0}) == value);
}
+// Checks that reading an index with no stored element from a
+// SparseElementsAttr yields a zero-valued attribute whose type is the
+// tensor's element type, for integer (i32), float (f32), and opaque
+// "string" element types.
+TEST(SparseElementsAttrTest, GetZero) {
+ MLIRContext context;
+ // Presumably required because the string element type below comes from
+ // the "test" dialect, which is never registered in this context.
+ context.allowUnregisteredDialects();
+
+ IntegerType intTy = IntegerType::get(&context, 32);
+ FloatType floatTy = FloatType::getF32(&context);
+ // An opaque type in the "test" dialect stands in for a string element
+ // type; its values are stored as a string-valued DenseElementsAttr below.
+ Type stringTy = OpaqueType::get(Identifier::get("test", &context), "string");
+
+ // All three sparse attributes share the same 2x2 logical shape.
+ ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
+ ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
+ ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
+
+ // Indices tensor of shape [1, 2]: one stored entry, located at
+ // coordinate (0, 0).
+ auto indicesType =
+ RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
+ auto indices =
+ DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
+
+ // Each values tensor holds exactly one element, matching the single
+ // entry in `indices`: 1, 1.0f and "foo" respectively.
+ RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
+ auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
+
+ RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
+ auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
+
+ RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
+ auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
+
+ auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
+ auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
+ auto sparseString =
+ SparseElementsAttr::get(tensorString, indices, stringValue);
+
+ // Only index (0, 0) holds a stored element. Any other index, such as
+ // (1, 1), should yield the element type's zero value: integer 0, float
+ // 0.0, or an empty string.
+ auto zeroIntValue = sparseInt.getValue({1, 1});
+ EXPECT_EQ(zeroIntValue.cast<IntegerAttr>().getInt(), 0);
+ EXPECT_TRUE(zeroIntValue.getType() == intTy);
+
+ // Exact comparison is fine here: the expected value is exactly zero.
+ auto zeroFloatValue = sparseFloat.getValue({1, 1});
+ EXPECT_EQ(zeroFloatValue.cast<FloatAttr>().getValueAsDouble(), 0.0f);
+ EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
+
+ // The string case must produce a StringAttr (the cast would fail on an
+ // IntegerAttr) that keeps the opaque element type, not a default type.
+ auto zeroStringValue = sparseString.getValue({1, 1});
+ EXPECT_TRUE(zeroStringValue.cast<StringAttr>().getValue().empty());
+ EXPECT_TRUE(zeroStringValue.getType() == stringTy);
+}
+
} // end namespace
More information about the Mlir-commits mailing list