[Mlir-commits] [mlir] 40e2bb5 - [mlir][spirv] Add folding for Bitwise[Or|And|Xor] (#74193)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 11 10:09:45 PST 2023


Author: Finn Plummer
Date: 2023-12-11T13:09:40-05:00
New Revision: 40e2bb5330840b56d452244f96e491b6530ce4bf

URL: https://github.com/llvm/llvm-project/commit/40e2bb5330840b56d452244f96e491b6530ce4bf
DIFF: https://github.com/llvm/llvm-project/commit/40e2bb5330840b56d452244f96e491b6530ce4bf.diff

LOG: [mlir][spirv] Add folding for Bitwise[Or|And|Xor] (#74193)

Add the missing constant-propagation folder for Bitwise[Or|And|Xor].

Move previous Bitwise[Or|And] fold implementations to
SPIRVCanonicalization for consistency.

Implement additional folding for Xor when lhs == rhs and when rhs == 0. Also
update an Xor test case to account for the newly introduced folding.

This improves the readability of code lowered to SPIR-V.

Part of work for #70704

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
    mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 7423bb3c68d7c..b460c8e68aa0c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -334,6 +334,8 @@ def SPIRV_BitwiseXorOp : SPIRV_BitBinaryOp<"BitwiseXor",
     %2 = spirv.BitwiseXor %0, %1 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 4643fc08c08d3..9de1707dfca46 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -722,9 +722,6 @@ OpFoldResult spirv::ShiftLeftLogicalOp::fold(
 
   // Unfortunately due to below undefined behaviour can't fold 0 for Base.
 
-  // According to the SPIR-V spec:
-  //
-  // Type is a scalar or vector of integer type.
   // Results are computed per component, and within each component, per bit...
   //
   // The result is undefined if Shift is greater than or equal to the bit width
@@ -756,9 +753,6 @@ OpFoldResult spirv::ShiftRightArithmeticOp::fold(
 
   // Unfortunately due to below undefined behaviour can't fold 0, -1 for Base.
 
-  // According to the SPIR-V spec:
-  //
-  // Type is a scalar or vector of integer type.
   // Results are computed per component, and within each component, per bit...
   //
   // The result is undefined if Shift is greater than or equal to the bit width
@@ -790,9 +784,6 @@ OpFoldResult spirv::ShiftRightLogicalOp::fold(
 
   // Unfortunately due to below undefined behaviour can't fold 0 for Base.
 
-  // According to the SPIR-V spec:
-  //
-  // Type is a scalar or vector of integer type.
   // Results are computed per component, and within each component, per bit...
   //
   // The result is undefined if Shift is greater than or equal to the bit width
@@ -811,6 +802,102 @@ OpFoldResult spirv::ShiftRightLogicalOp::fold(
   return shiftToLarge ? Attribute() : res;
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAndOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
+  // x & x -> x
+  if (getOperand1() == getOperand2()) {
+    return getOperand1();
+  }
+
+  APInt rhsMask;
+  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
+    // x & 0 -> 0 (the zero constant operand itself is the result)
+    if (rhsMask.isZero())
+      return getOperand2();
+
+    // x & <all ones> -> x
+    if (rhsMask.isAllOnes())
+      return getOperand1();
+
+    // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
+    if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
+      int valueBits =
+          getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
+      if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
+        return getOperand1();
+    }
+  }
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit.
+  // So two constant operands fold element-wise via APInt operator&.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a & b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
+  // x | x -> x
+  if (getOperand1() == getOperand2()) {
+    return getOperand1();
+  }
+
+  APInt rhsMask;
+  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
+    // x | 0 -> x
+    if (rhsMask.isZero())
+      return getOperand1();
+
+    // x | <all ones> -> <all ones> (reuse the all-ones constant operand)
+    if (rhsMask.isAllOnes())
+      return getOperand2();
+  }
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit.
+  // So two constant operands fold element-wise via APInt operator|.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a | b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseXorOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
+  // x ^ 0 -> x (scalar zero or all-zero vector constant)
+  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+    return getOperand1();
+  }
+
+  // x ^ x -> 0, returned as a zero attribute of the result type
+  if (getOperand1() == getOperand2())
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit.
+  // So two constant operands fold element-wise via APInt operator^.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a ^ b; });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3906bf74ea722..2a1d083308282 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1968,55 +1968,6 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
   return verifyShiftOp(*this);
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.BtiwiseAndOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult
-spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
-  APInt rhsMask;
-  if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
-    return {};
-
-  // x & 0 -> 0
-  if (rhsMask.isZero())
-    return getOperand2();
-
-  // x & <all ones> -> x
-  if (rhsMask.isAllOnes())
-    return getOperand1();
-
-  // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
-  if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
-    int valueBits =
-        getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
-    if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
-      return getOperand1();
-  }
-
-  return {};
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.BtiwiseOrOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
-  APInt rhsMask;
-  if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
-    return {};
-
-  // x | 0 -> x
-  if (rhsMask.isZero())
-    return getOperand1();
-
-  // x | <all ones> -> <all ones>
-  if (rhsMask.isAllOnes())
-    return getOperand2();
-
-  return {};
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.ImageQuerySize
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
index 82a2316f6c784..f3f0ebf60f468 100644
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -149,14 +149,16 @@ func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
 //===----------------------------------------------------------------------===//
 
 func.func @bitwise_xor_scalar(%arg: i32) -> i32 {
+  %c1 = spirv.Constant 1 : i32 // distinct, non-zero operand keeps BitwiseXor from folding
   // CHECK: spirv.BitwiseXor
-  %0 = spirv.BitwiseXor %arg, %arg : i32
+  %0 = spirv.BitwiseXor %c1, %arg : i32
   return %0 : i32
 }
 
 func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+  %c1 = spirv.Constant dense<1> : vector<4xi32> // distinct, non-zero operand keeps BitwiseXor from folding
   // CHECK: spirv.BitwiseXor
-  %0 = spirv.BitwiseXor %arg, %arg : vector<4xi32>
+  %0 = spirv.BitwiseXor %c1, %arg : vector<4xi32>
   return %0 : vector<4xi32>
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 6be4b8d20ff36..29bea91ce461d 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1317,6 +1317,194 @@ func.func @const_fold_vector_lsr() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAnd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_and_x_x
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_and_x_x(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %0 = spirv.BitwiseAnd %arg0, %arg0 : i32
+  %1 = spirv.BitwiseAnd %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @bitwise_and_x_0
+func.func @bitwise_and_x_0(%arg0 : i32, %arg1 : vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0 : i32
+  // CHECK-DAG: %[[CV0:.*]] = spirv.Constant dense<0> : vector<3xi32>
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.BitwiseAnd %arg0, %c0 : i32
+  %1 = spirv.BitwiseAnd %arg1, %cv0 : vector<3xi32>
+
+  // CHECK: return %[[C0]], %[[CV0]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @bitwise_and_x_n1
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_and_x_n1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %cn1 = spirv.Constant -1 : i32
+  %cvn1 = spirv.Constant dense<-1> : vector<3xi32>
+  %0 = spirv.BitwiseAnd %arg0, %cn1 : i32
+  %1 = spirv.BitwiseAnd %arg1, %cvn1 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_band
+func.func @const_fold_scalar_band() -> i32 {
+  %c1 = spirv.Constant -268464129 : i32   // 0xefff 8fff
+  %c2 = spirv.Constant 268464128: i32     // 0x1000 7000
+
+  // 0xefff 8fff & 0x1000 7000 = 0x0000 0000 = 0
+  // CHECK: %[[C0:.*]] = spirv.Constant 0
+  %0 = spirv.BitwiseAnd %c1, %c2 : i32
+
+  // CHECK: return %[[C0]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_band
+func.func @const_fold_vector_band() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: %[[CV:.*]] = spirv.Constant dense<[40, -63, 28]>
+  %0 = spirv.BitwiseAnd %c1, %c2 : vector<3xi32>
+
+  // CHECK: return %[[CV]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_or_x_x
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_or_x_x(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %0 = spirv.BitwiseOr %arg0, %arg0 : i32
+  %1 = spirv.BitwiseOr %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @bitwise_or_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_or_x_0(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c1 = spirv.Constant 0 : i32 // holds 0, despite the %c1 name
+  %cv1 = spirv.Constant dense<0> : vector<3xi32> // holds 0 in every lane
+  %0 = spirv.BitwiseOr %arg0, %c1 : i32
+  %1 = spirv.BitwiseOr %arg1, %cv1 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @bitwise_or_x_n1
+func.func @bitwise_or_x_n1(%arg0 : i32, %arg1 : vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1 : i32
+  // CHECK-DAG: %[[CVN1:.*]] = spirv.Constant dense<-1> : vector<3xi32>
+  %cn1 = spirv.Constant -1 : i32
+  %cvn1 = spirv.Constant dense<-1> : vector<3xi32>
+  %0 = spirv.BitwiseOr %arg0, %cn1 : i32
+  %1 = spirv.BitwiseOr %arg1, %cvn1 : vector<3xi32>
+
+  // CHECK: return %[[CN1]], %[[CVN1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_bor
+func.func @const_fold_scalar_bor() -> i32 {
+  %c1 = spirv.Constant -268464129 : i32   // 0xefff 8fff
+  %c2 = spirv.Constant 268464128: i32     // 0x1000 7000
+
+  // 0xefff 8fff | 0x1000 7000 = 0xffff ffff = -1
+  // CHECK: %[[CN1:.*]] = spirv.Constant -1
+  %0 = spirv.BitwiseOr %c1, %c2 : i32
+
+  // CHECK: return %[[CN1]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_bor
+func.func @const_fold_vector_bor() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: %[[CV:.*]] = spirv.Constant dense<[-1, -7, 127]>
+  %0 = spirv.BitwiseOr %c1, %c2 : vector<3xi32>
+
+  // CHECK: return %[[CV]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseXor
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_xor_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_xor_x_0(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.BitwiseXor %arg0, %c0 : i32
+  %1 = spirv.BitwiseXor %arg1, %cv0 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @bitwise_xor_x_x
+func.func @bitwise_xor_x_x(%arg0 : i32, %arg1 : vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CV0:.*]] = spirv.Constant dense<0>
+  %0 = spirv.BitwiseXor %arg0, %arg0 : i32
+  %1 = spirv.BitwiseXor %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[C0]], %[[CV0]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_bxor
+func.func @const_fold_scalar_bxor() -> i32 {
+  %c1 = spirv.Constant 4294967295 : i32  // 2^32 - 1: 0xffff ffff
+  %c2 = spirv.Constant -2147483648 : i32 // -2^31   : 0x8000 0000
+
+  // 0xffff ffff ^ 0x8000 0000 = 0x7fff ffff = 2^31 - 1
+  // CHECK: %[[CBIG:.*]] = spirv.Constant 2147483647
+  %0 = spirv.BitwiseXor %c1, %c2 : i32
+
+  // CHECK: return %[[CBIG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_bxor
+func.func @const_fold_vector_bxor() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: %[[CV:.*]] = spirv.Constant dense<[-41, 56, 99]>
+  %0 = spirv.BitwiseXor %c1, %c2 : vector<3xi32>
+
+  // CHECK: return %[[CV]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list