[Mlir-commits] [mlir] aac844a - [mlir] Improve bitEnumContains methods.

Hanhan Wang llvmlistbot at llvm.org
Fri Sep 9 11:56:59 PDT 2022


Author: Hanhan Wang
Date: 2022-09-09T11:56:36-07:00
New Revision: aac844a4b1577947285b74d83a8e9740b9ab337b

URL: https://github.com/llvm/llvm-project/commit/aac844a4b1577947285b74d83a8e9740b9ab337b
DIFF: https://github.com/llvm/llvm-project/commit/aac844a4b1577947285b74d83a8e9740b9ab337b.diff

LOG: [mlir] Improve bitEnumContains methods.

https://github.com/llvm/llvm-project/commit/839b436c93604e042f74050cf2adadd75f30e898
changed bitEnumContains so that it requires all of the given bits to be set
("and" semantics). Based on the review discussion, both the "any" and the "all"
semantics are needed, so this revision replaces bitEnumContains with two
functions: bitEnumContainsAny and bitEnumContainsAll.

Reviewed By: krzysz00, antiagainst

Differential Revision: https://reviews.llvm.org/D133507

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
    mlir/tools/mlir-tblgen/EnumsGen.cpp
    mlir/unittests/TableGen/EnumsGenTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 7dd583426f7a..fe8dfc0db15b 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1442,9 +1442,12 @@ inline constexpr MyBitEnum operator~(MyBitEnum bits) {
   // Ensure only bits that can be present in the enum are set
   return static_cast<MyBitEnum>(~static_cast<uint32_t>(bits) & static_cast<uint32_t>(15u));
 }
-inline constexpr bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) {
+inline constexpr bool bitEnumContainsAll(MyBitEnum bits, MyBitEnum bit) {
   return (bits & bit) == bit;
 }
+inline constexpr bool bitEnumContainsAny(MyBitEnum bits, MyBitEnum bit) {
+  return (static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0;
+}
 inline constexpr MyBitEnum bitEnumClear(MyBitEnum bits, MyBitEnum bit) {
   return bits & ~bit;
 }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index bfdc24490316..4b5403025a34 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -260,7 +260,8 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
                                                 kMemoryAccessAttrName))
     return failure();
 
-  if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
+  if (spirv::bitEnumContainsAll(memoryAccessAttr,
+                                spirv::MemoryAccess::Aligned)) {
     // Parse integer attribute for alignment.
     Attribute alignmentAttr;
     Type i32Type = parser.getBuilder().getIntegerType(32);
@@ -290,7 +291,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
                                                 kSourceMemoryAccessAttrName))
     return failure();
 
-  if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
+  if (spirv::bitEnumContainsAll(memoryAccessAttr,
+                                spirv::MemoryAccess::Aligned)) {
     // Parse integer attribute for alignment.
     Attribute alignmentAttr;
     Type i32Type = parser.getBuilder().getIntegerType(32);
@@ -316,7 +318,7 @@ static void printMemoryAccessAttribute(
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
-    if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
+    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.alignment())) {
@@ -349,7 +351,7 @@ static void printSourceMemoryAccessAttribute(
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
-    if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
+    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.alignment())) {
@@ -407,7 +409,7 @@ static LogicalResult verifyImageOperands(Op imageOp,
       spirv::ImageOperands::MakeTexelVisible |
       spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
 
-  if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
+  if (spirv::bitEnumContainsAny(attr.getValue(), noSupportOperands))
     llvm_unreachable("unimplemented operands of Image Operands");
 
   return success();
@@ -491,8 +493,8 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
            << memAccessAttr;
   }
 
-  if (spirv::bitEnumContains(memAccess.getValue(),
-                             spirv::MemoryAccess::Aligned)) {
+  if (spirv::bitEnumContainsAll(memAccess.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
     if (!op->getAttr(kAlignmentAttrName)) {
       return memoryOp.emitOpError("missing alignment value");
     }
@@ -535,8 +537,8 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
            << memAccess;
   }
 
-  if (spirv::bitEnumContains(memAccess.getValue(),
-                             spirv::MemoryAccess::Aligned)) {
+  if (spirv::bitEnumContainsAll(memAccess.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
     if (!op->getAttr(kSourceAlignmentAttrName)) {
       return memoryOp.emitOpError("missing alignment value");
     }

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index d968cb8d3029..e81c052e6de0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -162,7 +162,7 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
   llvm::FastMathFlags ret;
   auto fmf = op.getFastmathFlags();
   for (auto it : handlers)
-    if (bitEnumContains(fmf, it.first))
+    if (bitEnumContainsAll(fmf, it.first))
       (ret.*(it.second))(true);
   return ret;
 }

diff  --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 19dcd31932da..a17e927bdbf6 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -138,7 +138,8 @@ getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
 // inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
 // inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
 // inline constexpr <enum-type> operator~(<enum-type> bits);
-// inline constexpr bool bitEnumContains(<enum-type> bits, <enum-type> bit);
+// inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit);
+// inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit);
 // inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit);
 // inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
 // bool value=true);
@@ -161,9 +162,12 @@ inline constexpr {0} operator~({0} bits) {{
   // Ensure only bits that can be present in the enum are set
   return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));
 }
-inline constexpr bool bitEnumContains({0} bits, {0} bit) {{
+inline constexpr bool bitEnumContainsAll({0} bits, {0} bit) {{
   return (bits & bit) == bit;
 }
+inline constexpr bool bitEnumContainsAny({0} bits, {0} bit) {{
+  return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;
+}
 inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{
   return bits & ~bit;
 }

diff  --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp
index d971a9b9e0ab..95df2344fc8f 100644
--- a/mlir/unittests/TableGen/EnumsGenTest.cpp
+++ b/mlir/unittests/TableGen/EnumsGenTest.cpp
@@ -142,10 +142,16 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForPrimaryGroupBitEnum) {
 }
 
 TEST(EnumsGenTest, GeneratedOperator) {
-  EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
-                              BitEnumWithNone::Bit0));
-  EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit0 & BitEnumWithNone::Bit3,
-                               BitEnumWithNone::Bit0));
+  EXPECT_TRUE(bitEnumContainsAll(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
+                                 BitEnumWithNone::Bit0));
+  EXPECT_FALSE(bitEnumContainsAll(BitEnumWithNone::Bit0 & BitEnumWithNone::Bit3,
+                                  BitEnumWithNone::Bit0));
+  EXPECT_FALSE(bitEnumContainsAll(
+      BitEnumWithNone::Bit0, BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3));
+  EXPECT_TRUE(bitEnumContainsAny(
+      BitEnumWithNone::Bit0, BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3));
+  EXPECT_FALSE(
+      bitEnumContainsAny(BitEnumWithNone::Bit3, BitEnumWithNone::Bit0));
 }
 
 TEST(EnumsGenTest, GeneratedSymbolToCustomStringFn) {


        


More information about the Mlir-commits mailing list