[Mlir-commits] [mlir] 5bc6ff6 - [mlir][spirv] Add some folders for spv.LogicalAnd/spv.LogicalOr
Lei Zhang
llvmlistbot at llvm.org
Wed Feb 26 12:14:01 PST 2020
Author: Lei Zhang
Date: 2020-02-26T15:13:37-05:00
New Revision: 5bc6ff6455ec663a5da2681d057d0f848817b388
URL: https://github.com/llvm/llvm-project/commit/5bc6ff6455ec663a5da2681d057d0f848817b388
DIFF: https://github.com/llvm/llvm-project/commit/5bc6ff6455ec663a5da2681d057d0f848817b388.diff
LOG: [mlir][spirv] Add some folders for spv.LogicalAnd/spv.LogicalOr
This commit handles folding spv.LogicalAnd/spv.LogicalOr when
one of the operands is constant true/false.
Differential Revision: https://reviews.llvm.org/D75195
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
mlir/test/Dialect/SPIRV/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
index 57d09b388b10..26b332c4c451 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
@@ -526,6 +526,8 @@ def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd", SPV_Bool, [Commutative]
%2 = spv.LogicalAnd %0, %1 : vector<4xi1>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -656,6 +658,8 @@ def SPV_LogicalOrOp : SPV_LogicalBinaryOp<"LogicalOr", SPV_Bool, [Commutative]>
%2 = spv.LogicalOr %0, %1 : vector<4xi1>
```
}];
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
index 32090f3d1ec0..2d1a66c301f8 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -24,6 +24,26 @@ using namespace mlir;
// Common utility functions
//===----------------------------------------------------------------------===//
+/// Returns true if `boolAttr` is either a scalar BoolAttr or a splat
+/// elements attribute of an i1 vector type, and its value equals `boolVal`.
+/// Returns false for a null attribute and for any attribute of another kind
+/// or type (including non-splat vectors), so callers may pass arbitrary
+/// fold operands.
+static bool isScalarOrSplatBoolAttr(Attribute boolAttr, bool boolVal) {
+  // A null attribute means the operand is not a known constant.
+  if (!boolAttr)
+    return false;
+
+  // Scalar case. dyn_cast rejects attributes that are not BoolAttr instead of
+  // asserting, which `cast` would do.
+  if (auto attr = boolAttr.dyn_cast<BoolAttr>())
+    return attr.getValue() == boolVal;
+
+  // Vector case. `dyn_cast` is required here: `cast` asserts on a mismatched
+  // type rather than returning a null VectorType.
+  auto vecType = boolAttr.getType().dyn_cast<VectorType>();
+  if (!vecType || !vecType.getElementType().isInteger(1))
+    return false;
+
+  // Only a splat (all elements identical) can equal a single bool value.
+  auto splatAttr = boolAttr.dyn_cast<SplatElementsAttr>();
+  if (!splatAttr)
+    return false;
+  auto splatVal = splatAttr.getSplatValue().dyn_cast<BoolAttr>();
+  return splatVal && splatVal.getValue() == boolVal;
+}
+
// Extracts an element from the given `composite` by following the given
// `indices`. Returns a null Attribute if error happens.
static Attribute extractCompositeElement(Attribute composite,
@@ -187,6 +207,24 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
[](APInt a, APInt b) { return a - b; });
}
+//===----------------------------------------------------------------------===//
+// spv.LogicalAnd
+//===----------------------------------------------------------------------===//
+
+/// Folds spv.LogicalAnd when either operand is a constant true/false (scalar
+/// or splat vector):
+///   x && true  -> x        true && x  -> x
+///   x && false -> false    false && x -> false
+/// Returns a null OpFoldResult when no identity applies.
+OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
+
+  // Constants may sit in either position: fold() is not guaranteed to be
+  // invoked only after constants have been moved to the trailing operand.
+  Attribute lhs = operands.front();
+  Attribute rhs = operands.back();
+
+  // x && true = x
+  if (isScalarOrSplatBoolAttr(rhs, true))
+    return operand1();
+  // true && x = x
+  if (isScalarOrSplatBoolAttr(lhs, true))
+    return operand2();
+
+  // x && false = false
+  if (isScalarOrSplatBoolAttr(rhs, false))
+    return rhs;
+  // false && x = false
+  if (isScalarOrSplatBoolAttr(lhs, false))
+    return lhs;
+
+  return Attribute();
+}
+
//===----------------------------------------------------------------------===//
// spv.LogicalNot
//===----------------------------------------------------------------------===//
@@ -198,6 +236,24 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
ConvertLogicalNotOfLogicalNotEqual>(context);
}
+//===----------------------------------------------------------------------===//
+// spv.LogicalOr
+//===----------------------------------------------------------------------===//
+
+/// Folds spv.LogicalOr when either operand is a constant true/false (scalar
+/// or splat vector):
+///   x || true  -> true     true || x  -> true
+///   x || false -> x        false || x -> x
+/// Returns a null OpFoldResult when no identity applies.
+OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
+
+  // Constants may sit in either position: fold() is not guaranteed to be
+  // invoked only after constants have been moved to the trailing operand.
+  Attribute lhs = operands.front();
+  Attribute rhs = operands.back();
+
+  // x || true = true
+  if (isScalarOrSplatBoolAttr(rhs, true))
+    return rhs;
+  // true || x = true
+  if (isScalarOrSplatBoolAttr(lhs, true))
+    return lhs;
+
+  // x || false = x
+  if (isScalarOrSplatBoolAttr(rhs, false))
+    return operand1();
+  // false || x = x
+  if (isScalarOrSplatBoolAttr(lhs, false))
+    return operand2();
+
+  return Attribute();
+}
+
//===----------------------------------------------------------------------===//
// spv.selection
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir
index c8e4ccf51176..72857d7d7de2 100644
--- a/mlir/test/Dialect/SPIRV/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir
@@ -362,6 +362,36 @@ func @const_fold_vector_isub() -> vector<3xi32> {
// -----
+//===----------------------------------------------------------------------===//
+// spv.LogicalAnd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @convert_logical_and_true_false_scalar
+// CHECK-SAME: %[[ARG:.+]]: i1
+func @convert_logical_and_true_false_scalar(%arg: i1) -> (i1, i1, i1, i1) {
+  %true = spv.constant true
+  // CHECK: %[[FALSE:.+]] = spv.constant false
+  %false = spv.constant false
+  // Cover each identity with the constant in both operand positions.
+  %0 = spv.LogicalAnd %true, %arg: i1
+  %1 = spv.LogicalAnd %arg, %false: i1
+  %2 = spv.LogicalAnd %arg, %true: i1
+  %3 = spv.LogicalAnd %false, %arg: i1
+  // CHECK: return %[[ARG]], %[[FALSE]], %[[ARG]], %[[FALSE]]
+  return %0, %1, %2, %3: i1, i1, i1, i1
+}
+
+// CHECK-LABEL: @convert_logical_and_true_false_vector
+// CHECK-SAME: %[[ARG:.+]]: vector<3xi1>
+func @convert_logical_and_true_false_vector(%arg: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>, vector<3xi1>, vector<3xi1>) {
+  %true = spv.constant dense<true> : vector<3xi1>
+  // CHECK: %[[FALSE:.+]] = spv.constant dense<false>
+  %false = spv.constant dense<false> : vector<3xi1>
+  // Cover each identity with the constant in both operand positions.
+  %0 = spv.LogicalAnd %true, %arg: vector<3xi1>
+  %1 = spv.LogicalAnd %arg, %false: vector<3xi1>
+  %2 = spv.LogicalAnd %arg, %true: vector<3xi1>
+  %3 = spv.LogicalAnd %false, %arg: vector<3xi1>
+  // CHECK: return %[[ARG]], %[[FALSE]], %[[ARG]], %[[FALSE]]
+  return %0, %1, %2, %3: vector<3xi1>, vector<3xi1>, vector<3xi1>, vector<3xi1>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spv.LogicalNot
//===----------------------------------------------------------------------===//
@@ -419,6 +449,36 @@ func @convert_logical_not_to_logical_equal(%arg0: vector<3xi1>, %arg1: vector<3x
// -----
+//===----------------------------------------------------------------------===//
+// spv.LogicalOr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @convert_logical_or_true_false_scalar
+// CHECK-SAME: %[[ARG:.+]]: i1
+func @convert_logical_or_true_false_scalar(%arg: i1) -> (i1, i1, i1, i1) {
+  // CHECK: %[[TRUE:.+]] = spv.constant true
+  %true = spv.constant true
+  %false = spv.constant false
+  // Cover each identity with the constant in both operand positions.
+  %0 = spv.LogicalOr %true, %arg: i1
+  %1 = spv.LogicalOr %arg, %false: i1
+  %2 = spv.LogicalOr %arg, %true: i1
+  %3 = spv.LogicalOr %false, %arg: i1
+  // CHECK: return %[[TRUE]], %[[ARG]], %[[TRUE]], %[[ARG]]
+  return %0, %1, %2, %3: i1, i1, i1, i1
+}
+
+// CHECK-LABEL: @convert_logical_or_true_false_vector
+// CHECK-SAME: %[[ARG:.+]]: vector<3xi1>
+func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>, vector<3xi1>, vector<3xi1>) {
+  // CHECK: %[[TRUE:.+]] = spv.constant dense<true>
+  %true = spv.constant dense<true> : vector<3xi1>
+  %false = spv.constant dense<false> : vector<3xi1>
+  // Cover each identity with the constant in both operand positions.
+  %0 = spv.LogicalOr %true, %arg: vector<3xi1>
+  %1 = spv.LogicalOr %arg, %false: vector<3xi1>
+  %2 = spv.LogicalOr %arg, %true: vector<3xi1>
+  %3 = spv.LogicalOr %false, %arg: vector<3xi1>
+  // CHECK: return %[[TRUE]], %[[ARG]], %[[TRUE]], %[[ARG]]
+  return %0, %1, %2, %3: vector<3xi1>, vector<3xi1>, vector<3xi1>, vector<3xi1>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spv.selection
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list