[Mlir-commits] [mlir] aac844a - [mlir] Improve bitEnumContains methods.
Hanhan Wang
llvmlistbot at llvm.org
Fri Sep 9 11:56:59 PDT 2022
Author: Hanhan Wang
Date: 2022-09-09T11:56:36-07:00
New Revision: aac844a4b1577947285b74d83a8e9740b9ab337b
URL: https://github.com/llvm/llvm-project/commit/aac844a4b1577947285b74d83a8e9740b9ab337b
DIFF: https://github.com/llvm/llvm-project/commit/aac844a4b1577947285b74d83a8e9740b9ab337b.diff
LOG: [mlir] Improve bitEnumContains methods.
https://github.com/llvm/llvm-project/commit/839b436c93604e042f74050cf2adadd75f30e898
changed the behavior of bitEnumContains. Based on the discussion, we also want
to support the "and" behavior (all of the given bits are set). This revision
splits it into two functions, bitEnumContainsAny and bitEnumContainsAll.
Reviewed By: krzysz00, antiagainst
Differential Revision: https://reviews.llvm.org/D133507
Added:
Modified:
mlir/docs/OpDefinitions.md
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
mlir/tools/mlir-tblgen/EnumsGen.cpp
mlir/unittests/TableGen/EnumsGenTest.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 7dd583426f7a..fe8dfc0db15b 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1442,9 +1442,12 @@ inline constexpr MyBitEnum operator~(MyBitEnum bits) {
// Ensure only bits that can be present in the enum are set
return static_cast<MyBitEnum>(~static_cast<uint32_t>(bits) & static_cast<uint32_t>(15u));
}
-inline constexpr bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) {
+inline constexpr bool bitEnumContainsAll(MyBitEnum bits, MyBitEnum bit) {
return (bits & bit) == bit;
}
+inline constexpr bool bitEnumContainsAny(MyBitEnum bits, MyBitEnum bit) {
+ return (static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0;
+}
inline constexpr MyBitEnum bitEnumClear(MyBitEnum bits, MyBitEnum bit) {
return bits & ~bit;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index bfdc24490316..4b5403025a34 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -260,7 +260,8 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
kMemoryAccessAttrName))
return failure();
- if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
+ if (spirv::bitEnumContainsAll(memoryAccessAttr,
+ spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
Attribute alignmentAttr;
Type i32Type = parser.getBuilder().getIntegerType(32);
@@ -290,7 +291,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
kSourceMemoryAccessAttrName))
return failure();
- if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
+ if (spirv::bitEnumContainsAll(memoryAccessAttr,
+ spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
Attribute alignmentAttr;
Type i32Type = parser.getBuilder().getIntegerType(32);
@@ -316,7 +318,7 @@ static void printMemoryAccessAttribute(
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
- if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
+ if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.alignment())) {
@@ -349,7 +351,7 @@ static void printSourceMemoryAccessAttribute(
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
- if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
+ if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.alignment())) {
@@ -407,7 +409,7 @@ static LogicalResult verifyImageOperands(Op imageOp,
spirv::ImageOperands::MakeTexelVisible |
spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
- if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
+ if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
llvm_unreachable("unimplemented operands of Image Operands");
return success();
@@ -491,8 +493,8 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
<< memAccessAttr;
}
- if (spirv::bitEnumContains(memAccess.getValue(),
- spirv::MemoryAccess::Aligned)) {
+ if (spirv::bitEnumContainsAll(memAccess.getValue(),
+ spirv::MemoryAccess::Aligned)) {
if (!op->getAttr(kAlignmentAttrName)) {
return memoryOp.emitOpError("missing alignment value");
}
@@ -535,8 +537,8 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
<< memAccess;
}
- if (spirv::bitEnumContains(memAccess.getValue(),
- spirv::MemoryAccess::Aligned)) {
+ if (spirv::bitEnumContainsAll(memAccess.getValue(),
+ spirv::MemoryAccess::Aligned)) {
if (!op->getAttr(kSourceAlignmentAttrName)) {
return memoryOp.emitOpError("missing alignment value");
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index d968cb8d3029..e81c052e6de0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -162,7 +162,7 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
llvm::FastMathFlags ret;
auto fmf = op.getFastmathFlags();
for (auto it : handlers)
- if (bitEnumContains(fmf, it.first))
+ if (bitEnumContainsAll(fmf, it.first))
(ret.*(it.second))(true);
return ret;
}
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 19dcd31932da..a17e927bdbf6 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -138,7 +138,8 @@ getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
// inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
// inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
// inline constexpr <enum-type> operator~(<enum-type> bits);
-// inline constexpr bool bitEnumContains(<enum-type> bits, <enum-type> bit);
+// inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit);
+// inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit);
// inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit);
// inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
// bool value=true);
@@ -161,9 +162,12 @@ inline constexpr {0} operator~({0} bits) {{
// Ensure only bits that can be present in the enum are set
return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));
}
-inline constexpr bool bitEnumContains({0} bits, {0} bit) {{
+inline constexpr bool bitEnumContainsAll({0} bits, {0} bit) {{
return (bits & bit) == bit;
}
+inline constexpr bool bitEnumContainsAny({0} bits, {0} bit) {{
+ return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;
+}
inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{
return bits & ~bit;
}
diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp
index d971a9b9e0ab..95df2344fc8f 100644
--- a/mlir/unittests/TableGen/EnumsGenTest.cpp
+++ b/mlir/unittests/TableGen/EnumsGenTest.cpp
@@ -142,10 +142,10 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForPrimaryGroupBitEnum) {
}
TEST(EnumsGenTest, GeneratedOperator) {
- EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
- BitEnumWithNone::Bit0));
- EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit0 & BitEnumWithNone::Bit3,
- BitEnumWithNone::Bit0));
+ EXPECT_TRUE(bitEnumContainsAll(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
+ BitEnumWithNone::Bit0));
+ EXPECT_FALSE(bitEnumContainsAll(BitEnumWithNone::Bit0 & BitEnumWithNone::Bit3,
+ BitEnumWithNone::Bit0));
}
TEST(EnumsGenTest, GeneratedSymbolToCustomStringFn) {
More information about the Mlir-commits
mailing list