[Mlir-commits] [mlir] 40e2bb5 - [mlir][spirv] Add folding for Bitwise[Or|And|Xor] (#74193)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 11 10:09:45 PST 2023
Author: Finn Plummer
Date: 2023-12-11T13:09:40-05:00
New Revision: 40e2bb5330840b56d452244f96e491b6530ce4bf
URL: https://github.com/llvm/llvm-project/commit/40e2bb5330840b56d452244f96e491b6530ce4bf
DIFF: https://github.com/llvm/llvm-project/commit/40e2bb5330840b56d452244f96e491b6530ce4bf.diff
LOG: [mlir][spirv] Add folding for Bitwise[Or|And|Xor] (#74193)
Add missing constant-propagation folders for Bitwise[Or|And|Xor].
Move previous Bitwise[Or|And] fold implementations to
SPIRVCanonicalization for consistency.
Implement additional folds for Xor when lhs == rhs (x ^ x -> 0) or when
rhs == 0 (x ^ 0 -> x), and add x & x -> x and x | x -> x for And/Or.
Also update an Xor test case to account for the new x ^ x fold.
This improves the readability of code lowered into SPIR-V.
Part of work for #70704
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 7423bb3c68d7c..b460c8e68aa0c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -334,6 +334,8 @@ def SPIRV_BitwiseXorOp : SPIRV_BitBinaryOp<"BitwiseXor",
%2 = spirv.BitwiseXor %0, %1 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 4643fc08c08d3..9de1707dfca46 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -722,9 +722,6 @@ OpFoldResult spirv::ShiftLeftLogicalOp::fold(
// Unfortunately due to below undefined behaviour can't fold 0 for Base.
- // According to the SPIR-V spec:
- //
- // Type is a scalar or vector of integer type.
// Results are computed per component, and within each component, per bit...
//
// The result is undefined if Shift is greater than or equal to the bit width
@@ -756,9 +753,6 @@ OpFoldResult spirv::ShiftRightArithmeticOp::fold(
// Unfortunately due to below undefined behaviour can't fold 0, -1 for Base.
- // According to the SPIR-V spec:
- //
- // Type is a scalar or vector of integer type.
// Results are computed per component, and within each component, per bit...
//
// The result is undefined if Shift is greater than or equal to the bit width
@@ -790,9 +784,6 @@ OpFoldResult spirv::ShiftRightLogicalOp::fold(
// Unfortunately due to below undefined behaviour can't fold 0 for Base.
- // According to the SPIR-V spec:
- //
- // Type is a scalar or vector of integer type.
// Results are computed per component, and within each component, per bit...
//
// The result is undefined if Shift is greater than or equal to the bit width
@@ -811,6 +802,102 @@ OpFoldResult spirv::ShiftRightLogicalOp::fold(
return shiftToLarge ? Attribute() : res;
}
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAndOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
+  // x & x -> x
+  if (getOperand1() == getOperand2())
+    return getOperand1();
+
+  APInt rhsMask;
+  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
+    // x & 0 -> 0
+    if (rhsMask.isZero())
+      return getOperand2();
+
+    // x & <all ones> -> x
+    if (rhsMask.isAllOnes())
+      return getOperand1();
+
+    // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x,
+    // as bits N and above of a zero-extended value are already clear.
+    if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
+      int valueBits =
+          getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
+      if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
+        return getOperand1();
+    }
+  }
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit.
+  // So we can use the APInt & method.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a & b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
+  // x | x -> x
+  if (getOperand1() == getOperand2())
+    return getOperand1();
+
+  APInt rhsMask;
+  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
+    // x | 0 -> x
+    if (rhsMask.isZero())
+      return getOperand1();
+
+    // x | <all ones> -> <all ones>; reuse the existing constant operand
+    // rather than materializing a new attribute.
+    if (rhsMask.isAllOnes())
+      return getOperand2();
+  }
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit.
+  // So we can use the APInt | method.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a | b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseXorOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
+  // x ^ 0 -> x
+  if (matchPattern(adaptor.getOperand2(), m_Zero()))
+    return getOperand1();
+
+  // x ^ x -> 0. There is no zero operand to reuse here, so build a zero
+  // attribute of the result type (a zero splat for vector results).
+  if (getOperand1() == getOperand2())
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit.
+  // So we can use the APInt ^ method.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &a, const APInt &b) { return a ^ b; });
+}
+
//===----------------------------------------------------------------------===//
// spirv.mlir.selection
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3906bf74ea722..2a1d083308282 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1968,55 +1968,6 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
return verifyShiftOp(*this);
}
-//===----------------------------------------------------------------------===//
-// spirv.BtiwiseAndOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult
-spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
- APInt rhsMask;
- if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
- return {};
-
- // x & 0 -> 0
- if (rhsMask.isZero())
- return getOperand2();
-
- // x & <all ones> -> x
- if (rhsMask.isAllOnes())
- return getOperand1();
-
- // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
- if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
- int valueBits =
- getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
- if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
- return getOperand1();
- }
-
- return {};
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.BtiwiseOrOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
- APInt rhsMask;
- if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
- return {};
-
- // x | 0 -> x
- if (rhsMask.isZero())
- return getOperand1();
-
- // x | <all ones> -> <all ones>
- if (rhsMask.isAllOnes())
- return getOperand2();
-
- return {};
-}
-
//===----------------------------------------------------------------------===//
// spirv.ImageQuerySize
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
index 82a2316f6c784..f3f0ebf60f468 100644
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -149,14 +149,16 @@ func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
//===----------------------------------------------------------------------===//
func.func @bitwise_xor_scalar(%arg: i32) -> i32 {
+  %c1 = spirv.Constant 1 : i32 // distinct operands: no x ^ x -> 0 fold
// CHECK: spirv.BitwiseXor
-  %0 = spirv.BitwiseXor %arg, %arg : i32
+  %0 = spirv.BitwiseXor %c1, %arg : i32
return %0 : i32
}
func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+  %c1 = spirv.Constant dense<1> : vector<4xi32> // distinct operands: no x ^ x fold
// CHECK: spirv.BitwiseXor
-  %0 = spirv.BitwiseXor %arg, %arg : vector<4xi32>
+  %0 = spirv.BitwiseXor %c1, %arg : vector<4xi32>
return %0 : vector<4xi32>
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 6be4b8d20ff36..29bea91ce461d 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1317,6 +1317,194 @@ func.func @const_fold_vector_lsr() -> vector<3xi32> {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAnd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_and_x_x
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_and_x_x(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %0 = spirv.BitwiseAnd %arg0, %arg0 : i32
+  %1 = spirv.BitwiseAnd %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @bitwise_and_x_0
+func.func @bitwise_and_x_0(%arg0 : i32, %arg1 : vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0 : i32
+  // CHECK-DAG: %[[CV0:.*]] = spirv.Constant dense<0> : vector<3xi32>
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.BitwiseAnd %arg0, %c0 : i32
+  %1 = spirv.BitwiseAnd %arg1, %cv0 : vector<3xi32>
+
+  // CHECK: return %[[C0]], %[[CV0]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @bitwise_and_x_n1
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_and_x_n1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %cn1 = spirv.Constant -1 : i32
+  %cvn1 = spirv.Constant dense<-1> : vector<3xi32>
+  %0 = spirv.BitwiseAnd %arg0, %cn1 : i32
+  %1 = spirv.BitwiseAnd %arg1, %cvn1 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_band
+func.func @const_fold_scalar_band() -> i32 {
+  %c1 = spirv.Constant -268464129 : i32 // 0xefff 8fff
+  %c2 = spirv.Constant 268464128 : i32 // 0x1000 7000
+
+  // 0xefff 8fff & 0x1000 7000 = 0x0000 0000 = 0
+  // CHECK: %[[C0:.*]] = spirv.Constant 0
+  %0 = spirv.BitwiseAnd %c1, %c2 : i32
+
+  // CHECK: return %[[C0]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_band
+func.func @const_fold_vector_band() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: %[[CV:.*]] = spirv.Constant dense<[40, -63, 28]>
+  %0 = spirv.BitwiseAnd %c1, %c2 : vector<3xi32>
+
+  // CHECK: return %[[CV]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_or_x_x
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_or_x_x(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %0 = spirv.BitwiseOr %arg0, %arg0 : i32
+  %1 = spirv.BitwiseOr %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @bitwise_or_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_or_x_0(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+  %0 = spirv.BitwiseOr %arg0, %c0 : i32
+  %1 = spirv.BitwiseOr %arg1, %cv0 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @bitwise_or_x_n1
+func.func @bitwise_or_x_n1(%arg0 : i32, %arg1 : vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[CN1:.*]] = spirv.Constant -1 : i32
+  // CHECK-DAG: %[[CVN1:.*]] = spirv.Constant dense<-1> : vector<3xi32>
+  %cn1 = spirv.Constant -1 : i32
+  %cvn1 = spirv.Constant dense<-1> : vector<3xi32>
+  %0 = spirv.BitwiseOr %arg0, %cn1 : i32
+  %1 = spirv.BitwiseOr %arg1, %cvn1 : vector<3xi32>
+
+  // CHECK: return %[[CN1]], %[[CVN1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_bor
+func.func @const_fold_scalar_bor() -> i32 {
+  %c1 = spirv.Constant -268464129 : i32 // 0xefff 8fff
+  %c2 = spirv.Constant 268464128 : i32 // 0x1000 7000
+
+  // 0xefff 8fff | 0x1000 7000 = 0xffff ffff = -1
+  // CHECK: %[[CN1:.*]] = spirv.Constant -1
+  %0 = spirv.BitwiseOr %c1, %c2 : i32
+
+  // CHECK: return %[[CN1]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_bor
+func.func @const_fold_vector_bor() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: %[[CV:.*]] = spirv.Constant dense<[-1, -7, 127]>
+  %0 = spirv.BitwiseOr %c1, %c2 : vector<3xi32>
+
+  // CHECK: return %[[CV]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseXor
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_xor_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @bitwise_xor_x_0(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.BitwiseXor %arg0, %c0 : i32
+  %1 = spirv.BitwiseXor %arg1, %cv0 : vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @bitwise_xor_x_x
+func.func @bitwise_xor_x_x(%arg0 : i32, %arg1 : vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C0:.*]] = spirv.Constant 0
+  // CHECK-DAG: %[[CV0:.*]] = spirv.Constant dense<0>
+  %0 = spirv.BitwiseXor %arg0, %arg0 : i32
+  %1 = spirv.BitwiseXor %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[C0]], %[[CV0]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_bxor
+func.func @const_fold_scalar_bxor() -> i32 {
+  %c1 = spirv.Constant 4294967295 : i32 // 2^32 - 1: 0xffff ffff
+  %c2 = spirv.Constant -2147483648 : i32 // -2^31 : 0x8000 0000
+
+  // 0xffff ffff ^ 0x8000 0000 = 0x7fff ffff = 2^31 - 1
+  // CHECK: %[[CBIG:.*]] = spirv.Constant 2147483647
+  %0 = spirv.BitwiseXor %c1, %c2 : i32
+
+  // CHECK: return %[[CBIG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_bxor
+func.func @const_fold_vector_bxor() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: %[[CV:.*]] = spirv.Constant dense<[-41, 56, 99]>
+  %0 = spirv.BitwiseXor %c1, %c2 : vector<3xi32>
+
+  // CHECK: return %[[CV]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.mlir.selection
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list