[Mlir-commits] [mlir] 63779fb - [mlir][spirv] Refactoring to avoid calling the same function twice

Lei Zhang llvmlistbot at llvm.org
Wed Feb 26 12:37:04 PST 2020


Author: Lei Zhang
Date: 2020-02-26T15:36:54-05:00
New Revision: 63779fb462d828d16b87f427a6490dded842ca15

URL: https://github.com/llvm/llvm-project/commit/63779fb462d828d16b87f427a6490dded842ca15
DIFF: https://github.com/llvm/llvm-project/commit/63779fb462d828d16b87f427a6490dded842ca15.diff

LOG: [mlir][spirv] Refactoring to avoid calling the same function twice

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
index 2d1a66c301f8..c705dc87bfa8 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -24,24 +24,23 @@ using namespace mlir;
 // Common utility functions
 //===----------------------------------------------------------------------===//
 
-/// Returns true if the given `irVal` is a scalar or splat vector constant of
-/// the given `boolVal`.
-static bool isScalarOrSplatBoolAttr(Attribute boolAttr, bool boolVal) {
+/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
+/// or splat vector bool constant.
+static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
   if (!boolAttr)
-    return false;
+    return llvm::None;
 
   auto type = boolAttr.getType();
   if (type.isInteger(1)) {
     auto attr = boolAttr.cast<BoolAttr>();
-    return attr.getValue() == boolVal;
+    return attr.getValue();
   }
   if (auto vecType = type.cast<VectorType>()) {
     if (vecType.getElementType().isInteger(1))
       if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
-        return attr.getSplatValue().template cast<BoolAttr>().getValue() ==
-               boolVal;
+        return attr.getSplatValue<bool>();
   }
-  return false;
+  return llvm::None;
 }
 
 // Extracts an element from the given `composite` by following the given
@@ -214,13 +213,15 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
 OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
 
-  // x && true = x
-  if (isScalarOrSplatBoolAttr(operands.back(), true))
-    return operand1();
+  if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
+    // x && true = x
+    if (rhs.getValue())
+      return operand1();
 
-  // x && false = false
-  if (isScalarOrSplatBoolAttr(operands.back(), false))
-    return operands.back();
+    // x && false = false
+    if (!rhs.getValue())
+      return operands.back();
+  }
 
   return Attribute();
 }
@@ -243,13 +244,15 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
 OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
 
-  // x || true = true
-  if (isScalarOrSplatBoolAttr(operands.back(), true))
-    return operands.back();
+  if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
+    if (rhs.getValue())
+      // x || true = true
+      return operands.back();
 
-  // x || false = x
-  if (isScalarOrSplatBoolAttr(operands.back(), false))
-    return operand1();
+    // x || false = x
+    if (!rhs.getValue())
+      return operand1();
+  }
 
   return Attribute();
 }


        


More information about the Mlir-commits mailing list