[Mlir-commits] [mlir] 63779fb - [mlir][spirv] Refactoring to avoid calling the same function twice
Lei Zhang
llvmlistbot at llvm.org
Wed Feb 26 12:37:04 PST 2020
Author: Lei Zhang
Date: 2020-02-26T15:36:54-05:00
New Revision: 63779fb462d828d16b87f427a6490dded842ca15
URL: https://github.com/llvm/llvm-project/commit/63779fb462d828d16b87f427a6490dded842ca15
DIFF: https://github.com/llvm/llvm-project/commit/63779fb462d828d16b87f427a6490dded842ca15.diff
LOG: [mlir][spirv] Refactoring to avoid calling the same function twice
Added:
Modified:
mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
index 2d1a66c301f8..c705dc87bfa8 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -24,24 +24,23 @@ using namespace mlir;
// Common utility functions
//===----------------------------------------------------------------------===//
-/// Returns true if the given `irVal` is a scalar or splat vector constant of
-/// the given `boolVal`.
-static bool isScalarOrSplatBoolAttr(Attribute boolAttr, bool boolVal) {
+/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
+/// or splat vector bool constant.
+static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
if (!boolAttr)
- return false;
+ return llvm::None;
auto type = boolAttr.getType();
if (type.isInteger(1)) {
auto attr = boolAttr.cast<BoolAttr>();
- return attr.getValue() == boolVal;
+ return attr.getValue();
}
if (auto vecType = type.cast<VectorType>()) {
if (vecType.getElementType().isInteger(1))
if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
- return attr.getSplatValue().template cast<BoolAttr>().getValue() ==
- boolVal;
+ return attr.getSplatValue<bool>();
}
- return false;
+ return llvm::None;
}
// Extracts an element from the given `composite` by following the given
@@ -214,13 +213,15 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
- // x && true = x
- if (isScalarOrSplatBoolAttr(operands.back(), true))
- return operand1();
+ if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
+ // x && true = x
+ if (rhs.getValue())
+ return operand1();
- // x && false = false
- if (isScalarOrSplatBoolAttr(operands.back(), false))
- return operands.back();
+ // x && false = false
+ if (!rhs.getValue())
+ return operands.back();
+ }
return Attribute();
}
@@ -243,13 +244,15 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
- // x || true = true
- if (isScalarOrSplatBoolAttr(operands.back(), true))
- return operands.back();
+ if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
+ if (rhs.getValue())
+ // x || true = true
+ return operands.back();
- // x || false = x
- if (isScalarOrSplatBoolAttr(operands.back(), false))
- return operand1();
+ // x || false = x
+ if (!rhs.getValue())
+ return operand1();
+ }
return Attribute();
}
More information about the Mlir-commits mailing list.