[Mlir-commits] [mlir] [mlir][spirv] Add folding for Bitwise[Or|And|Xor] (PR #74193)

Finn Plummer llvmlistbot at llvm.org
Sun Dec 10 05:23:59 PST 2023


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/74193

>From da997382f9e6781d6f9fb9fb97f89bc63ee4bdd2 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Fri, 24 Nov 2023 10:03:40 +0100
Subject: [PATCH 1/2] [mlir][spirv] Add folding for Bitwise[Or|And|Xor]

Add the missing constant-propagation folders for Bitwise[Or|And|Xor].

Move the existing Bitwise[Or|And] fold implementations to
SPIRVCanonicalization.cpp for consistency.

Implement additional Xor folding for the cases lhs == rhs and rhs == 0,
and update an Xor test case to account for the newly introduced folding.

This improves the readability of code lowered to SPIR-V.

Part of work for #70704
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBitOps.td      |   2 +
 .../SPIRV/IR/SPIRVCanonicalization.cpp        |  95 ++++++++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |  49 ------
 mlir/test/Dialect/SPIRV/IR/bit-ops.mlir       |   6 +-
 .../SPIRV/Transforms/canonicalize.mlir        | 155 ++++++++++++++++++
 5 files changed, 247 insertions(+), 60 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 7423bb3c68d7c..b460c8e68aa0c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -334,6 +334,8 @@ def SPIRV_BitwiseXorOp : SPIRV_BitBinaryOp<"BitwiseXor",
     %2 = spirv.BitwiseXor %0, %1 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 4643fc08c08d3..e98d054851d31 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -722,9 +722,6 @@ OpFoldResult spirv::ShiftLeftLogicalOp::fold(
 
   // Unfortunately due to below undefined behaviour can't fold 0 for Base.
 
-  // According to the SPIR-V spec:
-  //
-  // Type is a scalar or vector of integer type.
   // Results are computed per component, and within each component, per bit...
   //
   // The result is undefined if Shift is greater than or equal to the bit width
@@ -756,9 +753,6 @@ OpFoldResult spirv::ShiftRightArithmeticOp::fold(
 
   // Unfortunately due to below undefined behaviour can't fold 0, -1 for Base.
 
-  // According to the SPIR-V spec:
-  //
-  // Type is a scalar or vector of integer type.
   // Results are computed per component, and within each component, per bit...
   //
   // The result is undefined if Shift is greater than or equal to the bit width
@@ -790,9 +784,6 @@ OpFoldResult spirv::ShiftRightLogicalOp::fold(
 
   // Unfortunately due to below undefined behaviour can't fold 0 for Base.
 
-  // According to the SPIR-V spec:
-  //
-  // Type is a scalar or vector of integer type.
   // Results are computed per component, and within each component, per bit...
   //
   // The result is undefined if Shift is greater than or equal to the bit width
@@ -811,6 +802,92 @@ OpFoldResult spirv::ShiftRightLogicalOp::fold(
   return shiftToLarge ? Attribute() : res;
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAndOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
+  APInt rhsMask;
+  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
+    // x & 0 -> 0
+    if (rhsMask.isZero())
+      return getOperand2();
+
+    // x & <all ones> -> x
+    if (rhsMask.isAllOnes())
+      return getOperand1();
+
+    // (UConvert x : iN to iK) & <mask with the low N bits set> -> UConvert x
+    if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
+      int valueBits =
+          getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
+      if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
+        return getOperand1();
+    }
+  }
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit.
+  // So each component can be folded with APInt's bitwise & operator.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a & b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
+  APInt rhsMask;
+  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
+    // x | 0 -> x
+    if (rhsMask.isZero())
+      return getOperand1();
+
+    // x | <all ones> -> <all ones> (the constant operand itself)
+    if (rhsMask.isAllOnes())
+      return getOperand2();
+  }
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit.
+  // So each component can be folded with APInt's bitwise | operator.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a | b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseXorOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
+  // x ^ 0 -> x (scalar zero or an all-zero vector constant)
+  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+    return getOperand1();
+  }
+
+  // x ^ x -> 0, returned as a zero attribute of the result type
+  if (getOperand1() == getOperand2())
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit.
+  // So each component can be folded with APInt's bitwise ^ operator.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a ^ b; });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3906bf74ea722..2a1d083308282 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1968,55 +1968,6 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
   return verifyShiftOp(*this);
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.BtiwiseAndOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult
-spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
-  APInt rhsMask;
-  if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
-    return {};
-
-  // x & 0 -> 0
-  if (rhsMask.isZero())
-    return getOperand2();
-
-  // x & <all ones> -> x
-  if (rhsMask.isAllOnes())
-    return getOperand1();
-
-  // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
-  if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
-    int valueBits =
-        getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
-    if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
-      return getOperand1();
-  }
-
-  return {};
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.BtiwiseOrOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
-  APInt rhsMask;
-  if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
-    return {};
-
-  // x | 0 -> x
-  if (rhsMask.isZero())
-    return getOperand1();
-
-  // x | <all ones> -> <all ones>
-  if (rhsMask.isAllOnes())
-    return getOperand2();
-
-  return {};
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.ImageQuerySize
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
index 82a2316f6c784..f3f0ebf60f468 100644
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -149,14 +149,16 @@ func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
 //===----------------------------------------------------------------------===//
 
 func.func @bitwise_xor_scalar(%arg: i32) -> i32 {
+  %c1 = spirv.Constant 1 : i32 // using constant to avoid folding
   // CHECK: spirv.BitwiseXor
-  %0 = spirv.BitwiseXor %arg, %arg : i32
+  %0 = spirv.BitwiseXor %c1, %arg : i32
   return %0 : i32
 }
 
 func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+  %c1 = spirv.Constant dense<1> : vector<4xi32> // using constant to avoid folding
   // CHECK: spirv.BitwiseXor
-  %0 = spirv.BitwiseXor %arg, %arg : vector<4xi32>
+  %0 = spirv.BitwiseXor %c1, %arg : vector<4xi32>
   return %0 : vector<4xi32>
 }
 
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 6be4b8d20ff36..2d39f0ccb7642 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1317,6 +1317,161 @@ func.func @const_fold_vector_lsr() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAnd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_and_x_0
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_and_x_0(%arg0 : i32) -> i32 {
+  // CHECK: %[[C0:.*]] = spirv.Constant 0 : i32
+  %c1 = spirv.Constant 0 : i32
+  %0 = spirv.BitwiseAnd %arg0, %c1 : i32
+
+  // CHECK: return %[[C0]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @bitwise_and_x_n1
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_and_x_n1(%arg0 : i32) -> i32 {
+  %c1 = spirv.Constant -1 : i32
+  %0 = spirv.BitwiseAnd %arg0, %c1 : i32
+
+  // CHECK: return %[[ARG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_band
+func.func @const_fold_scalar_band() -> i32 {
+  %c1 = spirv.Constant -268464129 : i32   // 0xefff 8fff
+  %c2 = spirv.Constant 268464128: i32     // 0x1000 7000
+
+  // 0xefff 8fff & 0x1000 7000 = 0x0000 0000 = 0
+  // CHECK: %[[C0:.*]] = spirv.Constant 0
+  %0 = spirv.BitwiseAnd %c1, %c2 : i32
+
+  // CHECK: return %[[C0]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_band
+func.func @const_fold_vector_band() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: %[[CV:.*]] = spirv.Constant dense<[40, -63, 28]>
+  %0 = spirv.BitwiseAnd %c1, %c2 : vector<3xi32>
+
+  // CHECK: return %[[CV]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_or_x_0
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_or_x_0(%arg0 : i32) -> i32 {
+  %c1 = spirv.Constant 0 : i32
+  %0 = spirv.BitwiseOr %arg0, %c1 : i32
+
+  // CHECK: return %[[ARG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @bitwise_or_x_n1
+func.func @bitwise_or_x_n1(%arg0 : i32) -> i32 {
+  // CHECK: %[[CN1:.*]] = spirv.Constant -1 : i32
+  %c1 = spirv.Constant -1 : i32
+  %0 = spirv.BitwiseOr %arg0, %c1 : i32
+
+  // CHECK: return %[[CN1]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_bor
+func.func @const_fold_scalar_bor() -> i32 {
+  %c1 = spirv.Constant -268464129 : i32   // 0xefff 8fff
+  %c2 = spirv.Constant 268464128: i32     // 0x1000 7000
+
+  // 0xefff 8fff | 0x1000 7000 = 0xffff ffff = -1
+  // CHECK: %[[CN1:.*]] = spirv.Constant -1
+  %0 = spirv.BitwiseOr %c1, %c2 : i32
+
+  // CHECK: return %[[CN1]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_bor
+func.func @const_fold_vector_bor() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: %[[CV:.*]] = spirv.Constant dense<[-1, -7, 127]>
+  %0 = spirv.BitwiseOr %c1, %c2 : vector<3xi32>
+
+  // CHECK: return %[[CV]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseXor
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_xor_x_0
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_xor_x_0(%arg0 : i32) -> i32 {
+  %c1 = spirv.Constant 0 : i32
+  %0 = spirv.BitwiseXor %arg0, %c1 : i32
+
+  // CHECK: return %[[ARG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @bitwise_xor_x_x
+func.func @bitwise_xor_x_x(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CV0:.*]] = spirv.Constant dense<0>
+  %0 = spirv.BitwiseXor %arg0, %arg0 : i32
+  %1 = spirv.BitwiseXor %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[C0]], %[[CV0]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_bxor
+func.func @const_fold_scalar_bxor() -> i32 {
+  %c1 = spirv.Constant 4294967295 : i32  // 2^32 - 1: 0xffff ffff
+  %c2 = spirv.Constant -2147483648 : i32 // -2^31   : 0x8000 0000
+
+  // 0xffff ffff ^ 0x8000 0000 = 0x7fff ffff = 2147483647
+  // CHECK: %[[CBIG:.*]] = spirv.Constant 2147483647
+  %0 = spirv.BitwiseXor %c1, %c2 : i32
+
+  // CHECK: return %[[CBIG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_bxor
+func.func @const_fold_vector_bxor() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: %[[CV:.*]] = spirv.Constant dense<[-41, 56, 99]>
+  %0 = spirv.BitwiseXor %c1, %c2 : vector<3xi32>
+
+  // CHECK: return %[[CV]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//

>From 8fac6e243a176561d926520e4988993bd5a949e8 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Sun, 10 Dec 2023 13:00:26 +0100
Subject: [PATCH 2/2] Address review comments

- add folding for x & x -> x and x | x -> x
- extend test cases to exercise vector operands
---
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 10 ++
 .../SPIRV/Transforms/canonicalize.mlir        | 93 +++++++++++++------
 2 files changed, 73 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index e98d054851d31..9de1707dfca46 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -808,6 +808,11 @@ OpFoldResult spirv::ShiftRightLogicalOp::fold(
 
 OpFoldResult
 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
+  // x & x -> x
+  if (getOperand1() == getOperand2()) {
+    return getOperand1();
+  }
+
   APInt rhsMask;
   if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
     // x & 0 -> 0
@@ -842,6 +847,11 @@ spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
+  // x | x -> x
+  if (getOperand1() == getOperand2()) {
+    return getOperand1();
+  }
+
   APInt rhsMask;
   if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
     // x | 0 -> x
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 2d39f0ccb7642..29bea91ce461d 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1321,25 +1321,40 @@ func.func @const_fold_vector_lsr() -> vector<3xi32> {
 // spirv.BitwiseAnd
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @bitwise_and_x_x
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_and_x_x(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %0 = spirv.BitwiseAnd %arg0, %arg0 : i32
+  %1 = spirv.BitwiseAnd %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
 // CHECK-LABEL: @bitwise_and_x_0
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-func.func @bitwise_and_x_0(%arg0 : i32) -> i32 {
-  // CHECK: %[[C0:.*]] = spirv.Constant 0 : i32
-  %c1 = spirv.Constant 0 : i32
-  %0 = spirv.BitwiseAnd %arg0, %c1 : i32
+func.func @bitwise_and_x_0(%arg0 : i32, %arg1 : vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0 : i32
+  // CHECK-DAG: %[[CV0:.*]] = spirv.Constant dense<0> : vector<3xi32>
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
 
-  // CHECK: return %[[C0]]
-  return %0 : i32
+  %0 = spirv.BitwiseAnd %arg0, %c0 : i32
+  %1 = spirv.BitwiseAnd %arg1, %cv0 : vector<3xi32>
+
+  // CHECK: return %[[C0]], %[[CV0]]
+  return %0, %1 : i32, vector<3xi32>
 }
 
 // CHECK-LABEL: @bitwise_and_x_n1
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-func.func @bitwise_and_x_n1(%arg0 : i32) -> i32 {
-  %c1 = spirv.Constant -1 : i32
-  %0 = spirv.BitwiseAnd %arg0, %c1 : i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_and_x_n1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %cn1 = spirv.Constant -1 : i32
+  %cvn1 = spirv.Constant dense<-1> : vector<3xi32>
+  %0 = spirv.BitwiseAnd %arg0, %cn1 : i32
+  %1 = spirv.BitwiseAnd %arg1, %cvn1 : vector<3xi32>
 
-  // CHECK: return %[[ARG]]
-  return %0 : i32
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
 }
 
 // CHECK-LABEL: @const_fold_scalar_band
@@ -1373,24 +1388,39 @@ func.func @const_fold_vector_band() -> vector<3xi32> {
 // spirv.BitwiseOr
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @bitwise_or_x_x
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_or_x_x(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %0 = spirv.BitwiseOr %arg0, %arg0 : i32
+  %1 = spirv.BitwiseOr %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
 // CHECK-LABEL: @bitwise_or_x_0
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-func.func @bitwise_or_x_0(%arg0 : i32) -> i32 {
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_or_x_0(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
   %c1 = spirv.Constant 0 : i32
+  %cv1 = spirv.Constant dense<0> : vector<3xi32>
   %0 = spirv.BitwiseOr %arg0, %c1 : i32
+  %1 = spirv.BitwiseOr %arg1, %cv1 : vector<3xi32>
 
-  // CHECK: return %[[ARG]]
-  return %0 : i32
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
 }
 
 // CHECK-LABEL: @bitwise_or_x_n1
-func.func @bitwise_or_x_n1(%arg0 : i32) -> i32 {
-  // CHECK: %[[CN1:.*]] = spirv.Constant -1 : i32
-  %c1 = spirv.Constant -1 : i32
-  %0 = spirv.BitwiseOr %arg0, %c1 : i32
+func.func @bitwise_or_x_n1(%arg0 : i32, %arg1 : vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1 : i32
+  // CHECK-DAG: %[[CVN1:.*]] = spirv.Constant dense<-1> : vector<3xi32>
+  %cn1 = spirv.Constant -1 : i32
+  %cvn1 = spirv.Constant dense<-1> : vector<3xi32>
+  %0 = spirv.BitwiseOr %arg0, %cn1 : i32
+  %1 = spirv.BitwiseOr %arg1, %cvn1 : vector<3xi32>
 
-  // CHECK: return %[[CN1]]
-  return %0 : i32
+  // CHECK: return %[[CN1]], %[[CVN1]]
+  return %0, %1 : i32, vector<3xi32>
 }
 
 // CHECK-LABEL: @const_fold_scalar_bor
@@ -1425,17 +1455,20 @@ func.func @const_fold_vector_bor() -> vector<3xi32> {
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @bitwise_xor_x_0
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-func.func @bitwise_xor_x_0(%arg0 : i32) -> i32 {
-  %c1 = spirv.Constant 0 : i32
-  %0 = spirv.BitwiseXor %arg0, %c1 : i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_xor_x_0(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
 
-  // CHECK: return %[[ARG]]
-  return %0 : i32
+  %0 = spirv.BitwiseXor %arg0, %c0 : i32
+  %1 = spirv.BitwiseXor %arg1, %cv0 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
 }
 
 // CHECK-LABEL: @bitwise_xor_x_x
-func.func @bitwise_xor_x_x(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+func.func @bitwise_xor_x_x(%arg0 : i32, %arg1 : vector<3xi32>) -> (i32, vector<3xi32>) {
   // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
   // CHECK-DAG: %[[CV0:.*]] = spirv.Constant dense<0>
   %0 = spirv.BitwiseXor %arg0, %arg0 : i32



More information about the Mlir-commits mailing list