[Mlir-commits] [mlir] ae60a4a - [mlir] Fix DenseElementsAttr::mapValues(i1, splat).
Benjamin Kramer
llvmlistbot at llvm.org
Tue Sep 6 12:34:44 PDT 2022
Author: Chenguang Wang
Date: 2022-09-06T21:28:25+02:00
New Revision: ae60a4a0efff337425638d04005b33a73dc5792f
URL: https://github.com/llvm/llvm-project/commit/ae60a4a0efff337425638d04005b33a73dc5792f
DIFF: https://github.com/llvm/llvm-project/commit/ae60a4a0efff337425638d04005b33a73dc5792f.diff
LOG: [mlir] Fix DenseElementsAttr::mapValues(i1, splat).
A splat of bool (i1) is encoded as a single byte with all bits set [1]. Without this
change, this piece of code:
auto xs = builder.getI32TensorAttr({42, 42, 42, 42});
auto xs2 = xs.mapValues(builder.getI1Type(), [](const llvm::APInt &x) {
return x.isZero() ? llvm::APInt::getZero(1) : llvm::APInt::getAllOnes(1);
});
xs2.dump();
Prints:
dense<[true, false, false, false]> : tensor<4xi1>
This happens because only the first bit is set. The bug affects both
DenseIntElementsAttr::mapValues() and DenseFPElementsAttr::mapValues().
[1]: https://github.com/llvm/llvm-project/blob/e877b42e2c70813352c1963ea33e992f481d5cba/mlir/lib/IR/BuiltinAttributes.cpp#L984
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D132767
Added:
Modified:
mlir/lib/IR/BuiltinAttributes.cpp
mlir/unittests/IR/AttributeTest.cpp
Removed:
################################################################################
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 810672e3d7ed..22eff2dc34b9 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1526,7 +1526,12 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
// Check for the splat case.
if (attr.isSplat()) {
- processElt(*attr.begin(), /*index=*/0);
+ if (bitWidth == 1) {
+ // Handle the special encoding of splat of bool.
+ data[0] = mapping(*attr.begin()).isZero() ? 0 : -1;
+ } else {
+ processElt(*attr.begin(), /*index=*/0);
+ }
return newArrayType;
}
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index e393b83df78d..cffac41bd953 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -209,6 +209,40 @@ TEST(DenseScalarTest, ExtractZeroRankElement) {
auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
}
+
+TEST(DenseSplatMapValuesTest, I32ToTrue) {
+ MLIRContext context;
+ const int elementValue = 12;
+ IntegerType boolTy = IntegerType::get(&context, 1);
+ IntegerType intTy = IntegerType::get(&context, 32);
+ RankedTensorType shape = RankedTensorType::get({4}, intTy);
+
+ auto attr =
+ DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
+ .mapValues(boolTy, [](const APInt &x) {
+ return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
+ });
+ EXPECT_EQ(attr.getNumElements(), 4);
+ EXPECT_TRUE(attr.isSplat());
+ EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
+}
+
+TEST(DenseSplatMapValuesTest, I32ToFalse) {
+ MLIRContext context;
+ const int elementValue = 0;
+ IntegerType boolTy = IntegerType::get(&context, 1);
+ IntegerType intTy = IntegerType::get(&context, 32);
+ RankedTensorType shape = RankedTensorType::get({4}, intTy);
+
+ auto attr =
+ DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
+ .mapValues(boolTy, [](const APInt &x) {
+ return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
+ });
+ EXPECT_EQ(attr.getNumElements(), 4);
+ EXPECT_TRUE(attr.isSplat());
+ EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
+}
} // namespace
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list