[Mlir-commits] [mlir] 5bc6ff6 - [mlir][spirv] Add some folders for spv.LogicalAnd/spv.LogicalOr

Lei Zhang llvmlistbot at llvm.org
Wed Feb 26 12:14:01 PST 2020


Author: Lei Zhang
Date: 2020-02-26T15:13:37-05:00
New Revision: 5bc6ff6455ec663a5da2681d057d0f848817b388

URL: https://github.com/llvm/llvm-project/commit/5bc6ff6455ec663a5da2681d057d0f848817b388
DIFF: https://github.com/llvm/llvm-project/commit/5bc6ff6455ec663a5da2681d057d0f848817b388.diff

LOG: [mlir][spirv] Add some folders for spv.LogicalAnd/spv.LogicalOr

This commit handles folding spv.LogicalAnd/spv.LogicalOr when
one of the operands is constant true/false.

Differential Revision: https://reviews.llvm.org/D75195

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
    mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
    mlir/test/Dialect/SPIRV/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
index 57d09b388b10..26b332c4c451 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
@@ -526,6 +526,8 @@ def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd", SPV_Bool, [Commutative]
     %2 = spv.LogicalAnd %0, %1 : vector<4xi1>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -656,6 +658,8 @@ def SPV_LogicalOrOp : SPV_LogicalBinaryOp<"LogicalOr", SPV_Bool, [Commutative]>
     %2 = spv.LogicalOr %0, %1 : vector<4xi1>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
index 32090f3d1ec0..2d1a66c301f8 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -24,6 +24,26 @@ using namespace mlir;
 // Common utility functions
 //===----------------------------------------------------------------------===//
 
+/// Returns true if the given `boolAttr` is a scalar or splat vector bool
+/// constant equal to `boolVal`. A null `boolAttr` yields false.
+static bool isScalarOrSplatBoolAttr(Attribute boolAttr, bool boolVal) {
+  if (!boolAttr)
+    return false;
+
+  auto type = boolAttr.getType();
+  // dyn_cast below: a mismatched attribute or type yields false, no assert.
+  if (type.isInteger(1)) {
+    auto attr = boolAttr.dyn_cast<BoolAttr>();
+    return attr && attr.getValue() == boolVal;
+  }
+  if (auto vecType = type.dyn_cast<VectorType>()) {
+    if (vecType.getElementType().isInteger(1))
+      if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
+        return attr.getSplatValue().cast<BoolAttr>().getValue() == boolVal;
+  }
+  return false;
+}
+
 // Extracts an element from the given `composite` by following the given
 // `indices`. Returns a null Attribute if error happens.
 static Attribute extractCompositeElement(Attribute composite,
@@ -187,6 +207,24 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
                                         [](APInt a, APInt b) { return a - b; });
 }
 
+//===----------------------------------------------------------------------===//
+// spv.LogicalAnd
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
+  // Folds when the last operand is a scalar or splat constant true/false.
+  assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
+  Attribute rhs = operands.back();
+  // A false rhs absorbs the other operand: the result is that constant.
+  if (isScalarOrSplatBoolAttr(rhs, /*boolVal=*/false))
+    return rhs;
+  // A true rhs is the identity: the result is the first operand unchanged.
+  if (isScalarOrSplatBoolAttr(rhs, /*boolVal=*/true))
+    return operand1();
+  // Otherwise return a null attribute: nothing was folded.
+  return Attribute();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.LogicalNot
 //===----------------------------------------------------------------------===//
@@ -198,6 +236,24 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
                  ConvertLogicalNotOfLogicalNotEqual>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spv.LogicalOr
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
+  // Folds when the last operand is a scalar or splat constant true/false.
+  assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
+  Attribute rhs = operands.back();
+  // A false rhs is the identity: the result is the first operand unchanged.
+  if (isScalarOrSplatBoolAttr(rhs, /*boolVal=*/false))
+    return operand1();
+  // A true rhs absorbs the other operand: the result is that constant.
+  if (isScalarOrSplatBoolAttr(rhs, /*boolVal=*/true))
+    return rhs;
+  // Otherwise return a null attribute: nothing was folded.
+  return Attribute();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.selection
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir
index c8e4ccf51176..72857d7d7de2 100644
--- a/mlir/test/Dialect/SPIRV/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir
@@ -362,6 +362,36 @@ func @const_fold_vector_isub() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.LogicalAnd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @convert_logical_and_true_false_scalar
+// CHECK-SAME: %[[ARG:.+]]: i1
+func @convert_logical_and_true_false_scalar(%arg: i1) -> (i1, i1) {
+  %true = spv.constant true
+  // CHECK: %[[FALSE:.+]] = spv.constant false
+  %false = spv.constant false
+  %0 = spv.LogicalAnd %true, %arg: i1   // true && x: expected to fold to %arg
+  %1 = spv.LogicalAnd %arg, %false: i1  // x && false: expected to fold to %false
+  // CHECK: return %[[ARG]], %[[FALSE]]
+  return %0, %1: i1, i1
+}
+
+// CHECK-LABEL: @convert_logical_and_true_false_vector
+// CHECK-SAME: %[[ARG:.+]]: vector<3xi1>
+func @convert_logical_and_true_false_vector(%arg: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>) {
+  %true = spv.constant dense<true> : vector<3xi1>
+  // CHECK: %[[FALSE:.+]] = spv.constant dense<false>
+  %false = spv.constant dense<false> : vector<3xi1>
+  %0 = spv.LogicalAnd %true, %arg: vector<3xi1>   // splat true && x: expected to fold to %arg
+  %1 = spv.LogicalAnd %arg, %false: vector<3xi1>  // x && splat false: expected to fold to %false
+  // CHECK: return %[[ARG]], %[[FALSE]]
+  return %0, %1: vector<3xi1>, vector<3xi1>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.LogicalNot
 //===----------------------------------------------------------------------===//
@@ -419,6 +449,36 @@ func @convert_logical_not_to_logical_equal(%arg0: vector<3xi1>, %arg1: vector<3x
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.LogicalOr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @convert_logical_or_true_false_scalar
+// CHECK-SAME: %[[ARG:.+]]: i1
+func @convert_logical_or_true_false_scalar(%arg: i1) -> (i1, i1) {
+  // CHECK: %[[TRUE:.+]] = spv.constant true
+  %true = spv.constant true
+  %false = spv.constant false
+  %0 = spv.LogicalOr %true, %arg: i1   // true || x: expected to fold to %true
+  %1 = spv.LogicalOr %arg, %false: i1  // x || false: expected to fold to %arg
+  // CHECK: return %[[TRUE]], %[[ARG]]
+  return %0, %1: i1, i1
+}
+
+// CHECK-LABEL: @convert_logical_or_true_false_vector
+// CHECK-SAME: %[[ARG:.+]]: vector<3xi1>
+func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>) {
+  // CHECK: %[[TRUE:.+]] = spv.constant dense<true>
+  %true = spv.constant dense<true> : vector<3xi1>
+  %false = spv.constant dense<false> : vector<3xi1>
+  %0 = spv.LogicalOr %true, %arg: vector<3xi1>   // splat true || x: expected to fold to %true
+  %1 = spv.LogicalOr %arg, %false: vector<3xi1>  // x || splat false: expected to fold to %arg
+  // CHECK: return %[[TRUE]], %[[ARG]]
+  return %0, %1: vector<3xi1>, vector<3xi1>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.selection
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list