[Mlir-commits] [mlir] [mlir][spirv] Add basic arithmetic folds (PR #71414)
Finn Plummer
llvmlistbot at llvm.org
Thu Nov 16 11:22:20 PST 2023
https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/71414
>From 2ec719195706ec742ff31e870d0086528aaef4a5 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Mon, 6 Nov 2023 11:06:10 +0100
Subject: [PATCH 1/3] [mlir][spirv] Add basic arithmetic folds
We do not have basic constant folds for many of the arithmetic
operations which negatively impacts readability of lowered or otherwise
generated code. This commit works towards adding more operations,
namely: [SU]Div, [SU]Mod, SRem, SNegate, IAddCarry, [SU]MulExtended
Resolves #70704
TODO:
- missing test case for vectors not added yet (just lazy)
- missing IAddCarry and [SU]MulExtended as unclear how to construct
Composite Attribute
---
.../Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 11 +
.../SPIRV/IR/SPIRVCanonicalization.cpp | 182 +++++++++++++
.../SPIRV/Transforms/canonicalize.mlir | 254 ++++++++++++++++++
3 files changed, 447 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..a4d342addb86f81 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -534,6 +534,8 @@ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -573,6 +575,8 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -634,6 +638,8 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
%3 = spirv.SNegate %2 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -673,6 +679,8 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -707,6 +715,8 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -811,6 +821,7 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
```
}];
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..f596b5d1cfcbfbe 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -69,6 +69,14 @@ static Attribute extractCompositeElement(Attribute composite,
return {};
}
+// Returns true if the signed division `a / b` must not be folded: either
+// `b` is zero, or `a` is the minimum signed value and `b` is all ones (-1).
+static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
+  if (b.isZero())
+    return true;
+  return a.isMinSignedValue() && b.isAllOnes();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'erated canonicalizers
//===----------------------------------------------------------------------===//
@@ -290,6 +298,180 @@ OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
[](APInt a, const APInt &b) { return std::move(a) - b; });
}
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
+  // Dividing by one is the identity: sdiv (x, 1) = x.
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // The SPIR-V spec leaves two cases of signed division undefined:
+  //
+  //  - Operand 2 is 0.
+  //  - Operand 2 is -1 and Operand 1 is the minimum representable value for
+  //    the operands' type, which causes signed overflow.
+  //
+  // Refuse to fold if any pair of values hits one of them. The flag is
+  // sticky: once it is set no further division is evaluated.
+  bool sawUndefined = false;
+  auto folded = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &lhs, const APInt &rhs) {
+        sawUndefined = sawUndefined || isDivZeroOrOverflow(lhs, rhs);
+        // The value returned once the flag is set is discarded below.
+        return sawUndefined ? lhs : lhs.sdiv(rhs);
+      });
+  if (sawUndefined)
+    return Attribute();
+  return folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
+  // smod (x, 1) = 0. getZeroAttr also builds a splat zero for vector types.
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to SPIR-V spec:
+  //
+  //  Signed remainder operation for the remainder whose sign matches the sign
+  //  of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
+  //  undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  //  value for the operands' type, causing signed overflow. Otherwise, the
+  //  result is the remainder r of Operand 1 divided by Operand 2 where if
+  //  r ≠ 0, the sign of r is the same as the sign of Operand 2.
+  //
+  // So don't fold when any pair of values has undefined behaviour.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a; // Placeholder; the result is discarded below.
+        }
+        APInt c = a.abs().urem(b.abs()); // |a| mod |b|, sign fixed up below.
+        if (c.isZero())
+          return c;
+        if (b.isNegative()) {
+          APInt zero = APInt::getZero(c.getBitWidth());
+          return a.isNegative() ? (zero - c) : (b + c);
+        }
+        return a.isNegative() ? (b - c) : c;
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
+  // Negating twice cancels out: -(-x) = 0 - (0 - x) = x
+  auto input = getOperand();
+  if (auto innerNegate = input.getDefiningOp<spirv::SNegateOp>())
+    return innerNegate->getOperand(0);
+
+  // According to the SPIR-V spec:
+  //
+  //  Signed-integer subtract of Operand from zero.
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &value) {
+        // Build the zero at the operand's own bit width so that both sides
+        // of the subtraction are APInts of the same width.
+        const APInt zero = APInt::getZero(value.getBitWidth());
+        return zero - value;
+      });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
+  // x % 1 = 0. getZeroAttr also builds a splat zero for vector types.
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to SPIR-V spec:
+  //
+  //  Signed remainder operation for the remainder whose sign matches the sign
+  //  of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
+  //  undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  //  value for the operands' type, causing signed overflow. Otherwise, the
+  //  result is the remainder r of Operand 1 divided by Operand 2 where if
+  //  r ≠ 0, the sign of r is the same as the sign of Operand 1.
+
+  // Don't fold when any pair of values has undefined behaviour.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a; // Placeholder; the result is discarded below.
+        }
+        return a.srem(b);
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
+  // Dividing by one is the identity: udiv (x, 1) = x.
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // The SPIR-V spec defines this as the unsigned-integer division of
+  // Operand 1 by Operand 2 and leaves the behaviour undefined when Operand 2
+  // is 0, so refuse to fold if any divisor is zero. The flag is sticky: once
+  // it is set no further division is evaluated.
+  bool sawZeroDivisor = false;
+  auto quotient = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &lhs, const APInt &rhs) {
+        sawZeroDivisor = sawZeroDivisor || rhs.isZero();
+        // The value returned for a zero divisor is discarded below.
+        return sawZeroDivisor ? lhs : lhs.udiv(rhs);
+      });
+
+  if (sawZeroDivisor)
+    return Attribute();
+
+  return quotient;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
+  // umod (x, 1) = 0. getZeroAttr also builds a splat zero for vector types.
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  //
+  //  Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
+  //  undefined if Operand 2 is 0.
+  //
+  // So don't fold when any divisor is zero.
+  bool div0 = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0 || b.isZero()) {
+          div0 = true;
+          return a; // Placeholder; the result is discarded below.
+        }
+        return a.urem(b);
+      });
+  return div0 ? Attribute() : res;
+}
+
//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a..93eceb036a93a23 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -462,10 +462,264 @@ func.func @const_fold_vector_isub() -> vector<3xi32> {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sdiv_x_1
+func.func @sdiv_x_1(%arg0 : i32) -> i32 {
+ // CHECK-NEXT: return %arg0 : i32
+ %c1 = spirv.Constant 1 : i32
+ %2 = spirv.SDiv %arg0, %c1: i32
+ return %2 : i32
+}
+
+// CHECK-LABEL: @sdiv_div_0_or_overflow
+func.func @sdiv_div_0_or_overflow() -> (i32, i32) {
+ // CHECK: spirv.SDiv
+ // CHECK: spirv.SDiv
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %min_i32 = spirv.Constant -2147483648 : i32
+
+ %0 = spirv.SDiv %cn1, %c0 : i32
+ %1 = spirv.SDiv %min_i32, %cn1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_sdiv
+func.func @const_fold_scalar_sdiv() -> (i32, i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: spirv.Constant -18
+ // CHECK-DAG: spirv.Constant -2
+ // CHECK-DAG: spirv.Constant -7
+ // CHECK-DAG: spirv.Constant 8
+ %0 = spirv.SDiv %c56, %c7 : i32
+ %1 = spirv.SDiv %c56, %cn8 : i32
+ %2 = spirv.SDiv %cn8, %c3 : i32
+ %3 = spirv.SDiv %c56, %cn3 : i32
+ return %0, %1, %2, %3: i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smod_x_1
+func.func @smod_x_1(%arg0 : i32) -> i32 {
+ // CHECK: spirv.Constant 0
+ %c1 = spirv.Constant 1 : i32
+ %0 = spirv.SMod %arg0, %c1: i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @smod_div_0_or_overflow
+func.func @smod_div_0_or_overflow() -> (i32, i32) {
+ // CHECK: spirv.SMod
+ // CHECK: spirv.SMod
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %min_i32 = spirv.Constant -2147483648 : i32
+
+ %0 = spirv.SMod %cn1, %c0 : i32
+ %1 = spirv.SMod %min_i32, %cn1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_smod
+func.func @const_fold_scalar_smod() -> (i32, i32, i32, i32, i32, i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %cn56 = spirv.Constant -56 : i32
+ %c59 = spirv.Constant 59 : i32
+ %cn59 = spirv.Constant -59 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+ // CHECK-DAG: %[[FIFTYTHREE:.*]] = spirv.Constant 53 : i32
+ // CHECK-DAG: %[[NFIFTYTHREE:.*]] = spirv.Constant -53 : i32
+ // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+ // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+ %0 = spirv.SMod %c56, %c7 : i32
+ %1 = spirv.SMod %c56, %cn8 : i32
+ %2 = spirv.SMod %c56, %c3 : i32
+ %3 = spirv.SMod %cn3, %c56 : i32
+ %4 = spirv.SMod %cn3, %cn56 : i32
+ %5 = spirv.SMod %c59, %c56 : i32
+ %6 = spirv.SMod %c59, %cn56 : i32
+ %7 = spirv.SMod %cn59, %cn56 : i32
+
+ // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[FIFTYTHREE]], %[[NTHREE]], %[[THREE]], %[[NFIFTYTHREE]], %[[NTHREE]]
+ return %0, %1, %2, %3, %4, %5, %6, %7 : i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @snegate_twice
+func.func @snegate_twice(%arg0 : i32) -> i32 {
+ // CHECK-NEXT: return %arg0 : i32
+ %0 = spirv.SNegate %arg0 : i32
+ %1 = spirv.SNegate %0 : i32
+ return %1 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_snegate
+func.func @const_fold_scalar_snegate() -> (i32, i32, i32) {
+ %c0 = spirv.Constant 0 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+ // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+ // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ %0 = spirv.SNegate %c0 : i32
+ %1 = spirv.SNegate %c3 : i32
+ %2 = spirv.SNegate %cn3 : i32
+
+ // CHECK: return %[[ZERO]], %[[NTHREE]], %[[THREE]]
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @srem_x_1
+func.func @srem_x_1(%arg0 : i32) -> i32 {
+ // CHECK: spirv.Constant 0
+ %c1 = spirv.Constant 1 : i32
+ %0 = spirv.SRem %arg0, %c1 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @srem_div_0_or_overflow
+func.func @srem_div_0_or_overflow() -> (i32, i32) {
+ // CHECK: spirv.SRem
+ // CHECK: spirv.SRem
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %min_i32 = spirv.Constant -2147483648 : i32
+
+ %0 = spirv.SRem %cn1, %c0 : i32
+ %1 = spirv.SRem %min_i32, %cn1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_srem
+func.func @const_fold_scalar_srem() -> (i32, i32, i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[ONE:.*]] = spirv.Constant 1 : i32
+ // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+ // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+ // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ %0 = spirv.SRem %c56, %c7 : i32
+ %1 = spirv.SRem %c56, %cn8 : i32
+ %2 = spirv.SRem %c56, %c3 : i32
+ %3 = spirv.SRem %cn3, %c56 : i32
+ %4 = spirv.SRem %c7, %cn3 : i32
+ // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[NTHREE]], %[[ONE]]
+ return %0, %1, %2, %3, %4 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @udiv_x_1
+func.func @udiv_x_1(%arg0 : i32) -> i32 {
+ // CHECK-NEXT: return %arg0 : i32
+ %c1 = spirv.Constant 1 : i32
+ %2 = spirv.UDiv %arg0, %c1: i32
+ return %2 : i32
+}
+
+// CHECK-LABEL: @udiv_div_0
+func.func @udiv_div_0() -> i32 {
+ // CHECK: spirv.UDiv
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %0 = spirv.UDiv %cn1, %c0 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_udiv
+func.func @const_fold_scalar_udiv() -> (i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant 1431655762
+ // CHECK-DAG: spirv.Constant 8
+ %0 = spirv.UDiv %c56, %c7 : i32
+ %1 = spirv.UDiv %cn8, %c3 : i32
+ %2 = spirv.UDiv %c56, %cn8 : i32
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @umod_x_1
+func.func @umod_x_1(%arg0 : i32) -> i32 {
+  // CHECK: spirv.Constant 0
+  %c1 = spirv.Constant 1 : i32
+  %2 = spirv.UMod %arg0, %c1: i32
+  return %2 : i32
+}
+
+// CHECK-LABEL: @umod_div_0
+func.func @umod_div_0() -> i32 {
+ // CHECK: spirv.UMod
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %0 = spirv.UMod %cn1, %c0 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_umod
+func.func @const_fold_scalar_umod() -> (i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant 2
+ // CHECK-DAG: spirv.Constant 56
+ %0 = spirv.UMod %c56, %c7 : i32
+ %1 = spirv.UMod %cn8, %c3 : i32
+ %2 = spirv.UMod %c56, %cn8 : i32
+ return %0, %1, %2 : i32, i32, i32
+}
+
// CHECK-LABEL: @umod_fold
// CHECK-SAME: (%[[ARG:.*]]: i32)
func.func @umod_fold(%arg0: i32) -> (i32, i32) {
>From 712976831f12a504b3903ef83e384724d4639039 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Mon, 6 Nov 2023 19:11:27 +0100
Subject: [PATCH 2/3] add IAddCarry/[SU]MulExtended as canonicalization
- struct constants are not yet implemented, so the result struct cannot be
built as a constant attribute; we use CompositeConstruct instead. That
requires creating a new operation, so this is a canonicalization, not a fold
---
.../Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 6 +
.../SPIRV/IR/SPIRVCanonicalization.cpp | 187 ++++++++++++++++++
.../SPIRV/Transforms/canonicalize.mlir | 148 ++++++++++++++
3 files changed, 341 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index a4d342addb86f81..a73989c41c04cfb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
%2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
@@ -611,6 +613,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
%2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
@@ -752,6 +756,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
%2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index f596b5d1cfcbfbe..9f55f2807e458fb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -123,6 +123,193 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
results.add<CombineChainedAccessChain>(context);
}
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// Constant structs are not implemented yet, so the folded result has to be
+// materialized with CompositeConstructOp, which a fold hook cannot create.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // iaddcarry (x, 0) = <x, 0>: the sum is x and no carry is produced.
+    if (matchPattern(operands[1], m_Zero())) {
+      constituents.push_back(operands[0]);
+      constituents.push_back(operands[1]);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    //  Result Type must be from OpTypeStruct.  The struct must have two
+    //  members...
+    //
+    //  Member 0 of the result gets the low-order bits (full component width) of
+    //  the addition.
+    //
+    //  Member 1 of the result gets the high-order (carry) bit of the result of
+    //  the addition. That is, it gets the value 1 if the addition overflowed
+    //  the component width, and 0 otherwise.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto adds = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+    auto carrys = constFoldBinaryOp<IntegerAttr>(
+        ArrayRef{lhs, rhs}, [](const APInt &a, const APInt &b) {
+          APInt zero = APInt::getZero(a.getBitWidth());
+          return (a + b).ult(a) ? (zero + 1) : zero; // Wrapped sum => carry.
+        });
+
+    if (!adds || !carrys)
+      return failure();
+
+    // Member 0 is the sum, member 1 the carry.
+    Value carrysVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+    Value addsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+    constituents.push_back(addsVal);
+    constituents.push_back(carrysVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<IAddCarryFold>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[SU]MulExtended
+//===----------------------------------------------------------------------===//
+
+// Constant structs are not implemented yet, so the folded result has to be
+// materialized with CompositeConstructOp, which a fold hook cannot create.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+  using OpRewritePattern<MulOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    Type constituentType = operands[0].getType();
+
+    // [su]mulextended (x, 0) = <0, 0>
+    if (matchPattern(operands[1], m_Zero())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      SmallVector<Value> zeroPair{zero, zero};
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               zeroPair);
+      return success();
+    }
+
+    // Only a pair of constant operands can be folded any further.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    // According to the SPIR-V spec:
+    //
+    //  Result Type must be from OpTypeStruct.  The struct must have two
+    //  members...
+    //
+    //  Member 0 of the result gets the low-order bits of the multiplication.
+    //
+    //  Member 1 of the result gets the high-order bits of the multiplication.
+    auto lowBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
+
+    if (!lowBits)
+      return failure();
+
+    // Redo the multiplication at twice the operand width, sign-extending
+    // the inputs when IsSigned and zero-extending them otherwise, then keep
+    // the upper half of the product.
+    auto highBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) {
+          const unsigned width = a.getBitWidth();
+          const unsigned doubled = width * 2;
+          const APInt product = IsSigned ? a.sext(doubled) * b.sext(doubled)
+                                         : a.zext(doubled) * b.zext(doubled);
+          return product.extractBits(width, width);
+        });
+
+    if (!highBits)
+      return failure();
+
+    // The constants are created low-order first, matching the member order.
+    Value lowBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+
+    Value highBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+
+    SmallVector<Value> members{lowBitsVal, highBitsVal};
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             members);
+    return success();
+  }
+};
+
+// Signed instantiation, registered as SMulExtended's canonicalization.
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<SMulExtendedOpFold>(context);
+}
+
+// Rewrites umulextended (x, 1) into the struct <x, 0>.
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+                                PatternRewriter &rewriter) const override {
+    auto operands = op.getOperands();
+
+    // Nothing to do unless the right-hand operand is the constant one.
+    if (!matchPattern(operands[1], m_One()))
+      return failure();
+
+    // Build <x, 0>: x first, then a zero of the same type as x.
+    Location loc = op.getLoc();
+    Type constituentType = operands[0].getType();
+    Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+    SmallVector<Value> members{operands[0], zero};
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             members);
+    return success();
+  }
+};
+
+// Unsigned instantiation of MulExtendedFold, registered with the x * 1 case.
+using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
+void spirv::UMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
+}
+
//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 93eceb036a93a23..84430e2690339ad 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -336,6 +336,52 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iaddcarry_x_0
+func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ %c0 = spirv.Constant 0 : i32
+
+ // CHECK: spirv.CompositeConstruct
+ %0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iaddcarry
+func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant -3
+ // CHECK-DAG: spirv.CompositeConstruct
+ // CHECK-DAG: spirv.Constant 1
+ // CHECK-DAG: spirv.Constant -13
+ // CHECK-DAG: spirv.CompositeConstruct
+ %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_iaddcarry
+func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ %v0 = spirv.Constant dense<[5, -5, -1]> : vector<3xi32>
+ %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
+
+ // CHECK: spirv.Constant dense<[0, 1, 1]>
+ // CHECK-NEXT: spirv.Constant dense<[-3, -13, 0]>
+ // CHECK-NEXT: spirv.CompositeConstruct
+ %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.IMul
//===----------------------------------------------------------------------===//
@@ -400,6 +446,108 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smulextended_x_0
+func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ %c0 = spirv.Constant 0 : i32
+
+ // CHECK: spirv.CompositeConstruct
+ %0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_smulextended
+func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+
+ // CHECK-DAG: spirv.Constant -40
+ // CHECK-DAG: spirv.Constant -1
+ // CHECK-DAG: spirv.CompositeConstruct
+ // CHECK-DAG: spirv.Constant 40
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.CompositeConstruct
+ %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_smulextended
+func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+ %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+ // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+ // CHECK-NEXT: spirv.Constant dense<[2, 0, -1]>
+ // CHECK-NEXT: spirv.CompositeConstruct
+ %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umulextended_x_0
+func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ %c0 = spirv.Constant 0 : i32
+
+ // CHECK: spirv.CompositeConstruct
+ %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @umulextended_x_1
+func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c1 = spirv.Constant 1 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %arg0, %c1 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_umulextended
+func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+
+ // CHECK-DAG: spirv.Constant 40
+ // CHECK-DAG: spirv.Constant -13
+ // CHECK-DAG: spirv.CompositeConstruct
+ // CHECK-DAG: spirv.Constant -40
+ // CHECK-DAG: spirv.Constant 4
+ // CHECK-DAG: spirv.CompositeConstruct
+ %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_umulextended
+func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+ %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+ // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+ // CHECK-NEXT: spirv.Constant dense<[2, -13, 0]>
+ // CHECK-NEXT: spirv.CompositeConstruct
+ %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+
//===----------------------------------------------------------------------===//
// spirv.ISub
//===----------------------------------------------------------------------===//
>From 2c01c1455181f8ea5aa383bcef7e4b6b1a077800 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Thu, 16 Nov 2023 19:18:02 +0100
Subject: [PATCH 3/3] add bitwise And Xor Or folds
- move And/Or into SPIRV/../Canonicalization.cpp for consistency
- implement const folding for And/Or
- add Xor folding
---
.../mlir/Dialect/SPIRV/IR/SPIRVBitOps.td | 2 +
.../SPIRV/IR/SPIRVCanonicalization.cpp | 86 +++++++++++
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 49 ------
.../SPIRV/Transforms/canonicalize.mlir | 141 ++++++++++++++++++
4 files changed, 229 insertions(+), 49 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 286f4de6f90f621..611ad051f5d818e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -412,6 +412,8 @@ def SPIRV_BitwiseXorOp : SPIRV_BitBinaryOp<"BitwiseXor",
%2 = spirv.BitwiseXor %0, %1 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9f55f2807e458fb..901b871f35106be 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -725,6 +725,92 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
return Attribute();
}
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAndOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
+ APInt rhsMask;
+ if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
+ // x & 0 -> 0
+ if (rhsMask.isZero())
+ return getOperand2();
+
+ // x & <all ones> -> x
+ if (rhsMask.isAllOnes())
+ return getOperand1();
+
+ // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
+ if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
+ int valueBits =
+ getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
+ if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
+ return getOperand1();
+ }
+ }
+
+ // According to the SPIR-V spec:
+ //
+ // Type is a scalar or vector of integer type.
+ // Results are computed per component, and within each component, per bit.
+ // So we can use the APInt & method.
+ return constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) { return a & b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
+ APInt rhsMask;
+ if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
+ // x | 0 -> x
+ if (rhsMask.isZero())
+ return getOperand1();
+
+ // x | <all ones> -> <all ones> (reuse the constant right-hand operand)
+ if (rhsMask.isAllOnes())
+ return getOperand2();
+ }
+
+ // Otherwise try to constant fold. According to the SPIR-V spec:
+ //
+ // Type is a scalar or vector of integer type.
+ // Results are computed per component, and within each component, per bit.
+ // So we can use the APInt | method on each pair of components.
+ return constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) { return a | b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseXorOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
+ // x ^ 0 -> x
+ if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+ return getOperand1();
+ }
+
+ // x ^ x -> 0 (getZeroAttr covers scalar and vector result types alike)
+ if (getOperand1() == getOperand2())
+ return Builder(getContext()).getZeroAttr(getType());
+
+ // According to the SPIR-V spec:
+ //
+ // Type is a scalar or vector of integer type.
+ // Results are computed per component, and within each component, per bit.
+ // So we can use the APInt ^ method.
+ return constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) { return a ^ b; });
+}
+
//===----------------------------------------------------------------------===//
// spirv.mlir.selection
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3906bf74ea72235..2a1d083308282a8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1968,55 +1968,6 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
return verifyShiftOp(*this);
}
-//===----------------------------------------------------------------------===//
-// spirv.BtiwiseAndOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult
-spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
- APInt rhsMask;
- if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
- return {};
-
- // x & 0 -> 0
- if (rhsMask.isZero())
- return getOperand2();
-
- // x & <all ones> -> x
- if (rhsMask.isAllOnes())
- return getOperand1();
-
- // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
- if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
- int valueBits =
- getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
- if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
- return getOperand1();
- }
-
- return {};
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.BtiwiseOrOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
- APInt rhsMask;
- if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
- return {};
-
- // x | 0 -> x
- if (rhsMask.isZero())
- return getOperand1();
-
- // x | <all ones> -> <all ones>
- if (rhsMask.isAllOnes())
- return getOperand2();
-
- return {};
-}
-
//===----------------------------------------------------------------------===//
// spirv.ImageQuerySize
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 84430e2690339ad..c2ba83149db015f 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1062,6 +1062,147 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
// -----
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAnd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_and_x_0
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_and_x_0(%arg0 : i32) -> i32 {
+ // CHECK: %[[C0:.*]] = spirv.Constant 0 : i32
+ %c1 = spirv.Constant 0 : i32
+ %0 = spirv.BitwiseAnd %arg0, %c1 : i32
+
+ // CHECK: return %[[C0]]
+ return %0 : i32
+}
+
+// CHECK-LABEL: @bitwise_and_x_n1
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_and_x_n1(%arg0 : i32) -> i32 {
+ %c1 = spirv.Constant -1 : i32
+ %0 = spirv.BitwiseAnd %arg0, %c1 : i32
+
+ // CHECK: return %[[ARG]]
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_band
+func.func @const_fold_scalar_band() -> i32 {
+ %c1 = spirv.Constant -268464129 : i32 // 0xefff 8fff
+ %c2 = spirv.Constant 268464128: i32 // 0x1000 7000
+
+ // 0xefff 8fff & 0x1000 7000 = 0x0000 0000 = 0
+ // CHECK: spirv.Constant 0
+ %0 = spirv.BitwiseAnd %c1, %c2 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_band
+func.func @const_fold_vector_band() -> vector<3xi32> {
+ %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+ %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+ // CHECK: spirv.Constant dense<[40, -63, 28]>
+ %0 = spirv.BitwiseAnd %c1, %c2 : vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_or_x_0
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_or_x_0(%arg0 : i32) -> i32 {
+ %c1 = spirv.Constant 0 : i32
+ %0 = spirv.BitwiseOr %arg0, %c1 : i32
+
+ // CHECK: return %[[ARG]]
+ return %0 : i32
+}
+
+// CHECK-LABEL: @bitwise_or_x_n1
+func.func @bitwise_or_x_n1(%arg0 : i32) -> i32 {
+ // CHECK: %[[CN1:.*]] = spirv.Constant -1 : i32
+ %c1 = spirv.Constant -1 : i32
+ %0 = spirv.BitwiseOr %arg0, %c1 : i32
+
+ // CHECK: return %[[CN1]]
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_bor
+func.func @const_fold_scalar_bor() -> i32 {
+ %c1 = spirv.Constant -268464129 : i32 // 0xefff 8fff
+ %c2 = spirv.Constant 268464128: i32 // 0x1000 7000
+
+ // 0xefff 8fff | 0x1000 7000 = 0xffff ffff = -1
+ // CHECK: spirv.Constant -1
+ %0 = spirv.BitwiseOr %c1, %c2 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_bor
+func.func @const_fold_vector_bor() -> vector<3xi32> {
+ %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+ %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+ // CHECK: spirv.Constant dense<[-1, -7, 127]>
+ %0 = spirv.BitwiseOr %c1, %c2 : vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseXor
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_xor_x_0
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_xor_x_0(%arg0 : i32) -> i32 {
+ %c1 = spirv.Constant 0 : i32
+ %0 = spirv.BitwiseXor %arg0, %c1 : i32
+
+ // CHECK: return %[[ARG]]
+ return %0 : i32
+}
+
+// CHECK-LABEL: @bitwise_xor_x_x
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_xor_x_x(%arg0 : i32) -> i32 {
+ %0 = spirv.BitwiseXor %arg0, %arg0 : i32
+
+ // CHECK: spirv.Constant 0
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_bxor
+func.func @const_fold_scalar_bxor() -> i32 {
+ %c1 = spirv.Constant 4294967295 : i32 // 2^32 - 1: 0xffff ffff
+ %c2 = spirv.Constant -2147483648 : i32 // -2^31 : 0x8000 0000
+
+ // 0xffff ffff ^ 0x8000 0000 = 0x7fff ffff
+ // CHECK: spirv.Constant 2147483647
+ %0 = spirv.BitwiseXor %c1, %c2 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_bxor
+func.func @const_fold_vector_bxor() -> vector<3xi32> {
+ %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+ %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+ // CHECK: spirv.Constant dense<[-41, 56, 99]>
+ %0 = spirv.BitwiseXor %c1, %c2 : vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.mlir.selection
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list