[Mlir-commits] [mlir] [mlir][spirv] Add folding for Bitwise[Or|And|Xor] (PR #74193)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Dec 3 10:46:24 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Finn Plummer (inbelic)

<details>
<summary>Changes</summary>

Add the missing constant-propagation folders for Bitwise[Or|And|Xor].

Move the existing Bitwise[Or|And] fold implementations to SPIRVCanonicalization.cpp for consistency.

For Xor, also fold the cases lhs == rhs (x ^ x -> 0) and rhs == 0 (x ^ 0 -> x). Update the Xor test cases so that they are not folded away by the newly introduced x ^ x rule.

This improves the readability of code lowered to SPIR-V.

Part of work for #<!-- -->70704

---
Full diff: https://github.com/llvm/llvm-project/pull/74193.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td (+2) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp (+86) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (-49) 
- (modified) mlir/test/Dialect/SPIRV/IR/bit-ops.mlir (+4-2) 
- (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+155) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 286f4de6f90f6..2b22fc1795402 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -412,6 +412,8 @@ def SPIRV_BitwiseXorOp : SPIRV_BitBinaryOp<"BitwiseXor",
     %2 = spirv.BitwiseXor %0, %1 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af..38489fdf45f79 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -356,6 +356,92 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAndOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
+  APInt rhsMask;
+  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
+    // x & 0 -> 0
+    if (rhsMask.isZero())
+      return getOperand2();
+
+    // x & <all ones> -> x
+    if (rhsMask.isAllOnes())
+      return getOperand1();
+
+    // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
+    if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
+      int valueBits =
+          getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
+      if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
+        return getOperand1();
+    }
+  }
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit.
+  // So we can use the APInt & method.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a & b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
+  APInt rhsMask;
+  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
+    // x | 0 -> x
+    if (rhsMask.isZero())
+      return getOperand1();
+
+    // x | <all ones> -> <all ones>
+    if (rhsMask.isAllOnes())
+      return getOperand2();
+  }
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit.
+  // So we can use the APInt | method.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a | b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseXorOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
+  // x ^ 0 -> x
+  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+    return getOperand1();
+  }
+
+  // x ^ x -> 0
+  if (getOperand1() == getOperand2())
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit.
+  // So we can use the APInt ^ method.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a ^ b; });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3906bf74ea722..2a1d083308282 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1968,55 +1968,6 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
   return verifyShiftOp(*this);
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.BtiwiseAndOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult
-spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
-  APInt rhsMask;
-  if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
-    return {};
-
-  // x & 0 -> 0
-  if (rhsMask.isZero())
-    return getOperand2();
-
-  // x & <all ones> -> x
-  if (rhsMask.isAllOnes())
-    return getOperand1();
-
-  // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
-  if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
-    int valueBits =
-        getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
-    if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
-      return getOperand1();
-  }
-
-  return {};
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.BtiwiseOrOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
-  APInt rhsMask;
-  if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
-    return {};
-
-  // x | 0 -> x
-  if (rhsMask.isZero())
-    return getOperand1();
-
-  // x | <all ones> -> <all ones>
-  if (rhsMask.isAllOnes())
-    return getOperand2();
-
-  return {};
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.ImageQuerySize
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
index 82a2316f6c784..f3f0ebf60f468 100644
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -149,14 +149,16 @@ func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
 //===----------------------------------------------------------------------===//
 
 func.func @bitwise_xor_scalar(%arg: i32) -> i32 {
+  %c1 = spirv.Constant 1 : i32 // using constant to avoid folding
   // CHECK: spirv.BitwiseXor
-  %0 = spirv.BitwiseXor %arg, %arg : i32
+  %0 = spirv.BitwiseXor %c1, %arg : i32
   return %0 : i32
 }
 
 func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+  %c1 = spirv.Constant dense<1> : vector<4xi32> // using constant to avoid folding
   // CHECK: spirv.BitwiseXor
-  %0 = spirv.BitwiseXor %arg, %arg : vector<4xi32>
+  %0 = spirv.BitwiseXor %c1, %arg : vector<4xi32>
   return %0 : vector<4xi32>
 }
 
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a44439..a71e122af2e3f 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -660,6 +660,161 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAnd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_and_x_0
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_and_x_0(%arg0 : i32) -> i32 {
+  // CHECK: %[[C0:.*]] = spirv.Constant 0 : i32
+  %c1 = spirv.Constant 0 : i32
+  %0 = spirv.BitwiseAnd %arg0, %c1 : i32
+
+  // CHECK: return %[[C0]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @bitwise_and_x_n1
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_and_x_n1(%arg0 : i32) -> i32 {
+  %c1 = spirv.Constant -1 : i32
+  %0 = spirv.BitwiseAnd %arg0, %c1 : i32
+
+  // CHECK: return %[[ARG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_band
+func.func @const_fold_scalar_band() -> i32 {
+  %c1 = spirv.Constant -268464129 : i32   // 0xefff 8fff
+  %c2 = spirv.Constant 268464128: i32     // 0x1000 7000
+
+  // 0xefff 8fff & 0x1000 7000 = 0x0000 0000 = 0
+  // CHECK: %[[C0:.*]] = spirv.Constant 0
+  %0 = spirv.BitwiseAnd %c1, %c2 : i32
+
+  // CHECK: return %[[C0]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_band
+func.func @const_fold_vector_band() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: %[[CV:.*]] = spirv.Constant dense<[40, -63, 28]>
+  %0 = spirv.BitwiseAnd %c1, %c2 : vector<3xi32>
+
+  // CHECK: return %[[CV]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_or_x_0
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_or_x_0(%arg0 : i32) -> i32 {
+  %c1 = spirv.Constant 0 : i32
+  %0 = spirv.BitwiseOr %arg0, %c1 : i32
+
+  // CHECK: return %[[ARG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @bitwise_or_x_n1
+func.func @bitwise_or_x_n1(%arg0 : i32) -> i32 {
+  // CHECK: %[[CN1:.*]] = spirv.Constant -1 : i32
+  %c1 = spirv.Constant -1 : i32
+  %0 = spirv.BitwiseOr %arg0, %c1 : i32
+
+  // CHECK: return %[[CN1]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_bor
+func.func @const_fold_scalar_bor() -> i32 {
+  %c1 = spirv.Constant -268464129 : i32   // 0xefff 8fff
+  %c2 = spirv.Constant 268464128: i32     // 0x1000 7000
+
+  // 0xefff 8fff | 0x1000 7000 = 0xffff ffff = -1
+  // CHECK: %[[CN1:.*]] = spirv.Constant -1
+  %0 = spirv.BitwiseOr %c1, %c2 : i32
+
+  // CHECK: return %[[CN1]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_bor
+func.func @const_fold_vector_bor() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: %[[CV:.*]] = spirv.Constant dense<[-1, -7, 127]>
+  %0 = spirv.BitwiseOr %c1, %c2 : vector<3xi32>
+
+  // CHECK: return %[[CV]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseXor
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_xor_x_0
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_xor_x_0(%arg0 : i32) -> i32 {
+  %c1 = spirv.Constant 0 : i32
+  %0 = spirv.BitwiseXor %arg0, %c1 : i32
+
+  // CHECK: return %[[ARG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @bitwise_xor_x_x
+func.func @bitwise_xor_x_x(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CV0:.*]] = spirv.Constant dense<0>
+  %0 = spirv.BitwiseXor %arg0, %arg0 : i32
+  %1 = spirv.BitwiseXor %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[C0]], %[[CV0]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_bxor
+func.func @const_fold_scalar_bxor() -> i32 {
+  %c1 = spirv.Constant 4294967295 : i32  // 2^32 - 1: 0xffff ffff
+  %c2 = spirv.Constant -2147483648 : i32 // -2^31   : 0x8000 0000
+
+  // 0xffff ffff ^ 0x8000 0000 = 0x7fff ffff = 2^31 - 1
+  // CHECK: %[[CBIG:.*]] = spirv.Constant 2147483647
+  %0 = spirv.BitwiseXor %c1, %c2 : i32
+
+  // CHECK: return %[[CBIG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_bxor
+func.func @const_fold_vector_bxor() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: %[[CV:.*]] = spirv.Constant dense<[-41, 56, 99]>
+  %0 = spirv.BitwiseXor %c1, %c2 : vector<3xi32>
+
+  // CHECK: return %[[CV]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//

``````````

</details>


https://github.com/llvm/llvm-project/pull/74193


More information about the Mlir-commits mailing list