[Mlir-commits] [mlir] [mlir][spirv] Add basic arithmetic folds (PR #71414)
Finn Plummer
llvmlistbot at llvm.org
Wed Nov 8 07:28:01 PST 2023
https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/71414
>From 2ec719195706ec742ff31e870d0086528aaef4a5 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Mon, 6 Nov 2023 11:06:10 +0100
Subject: [PATCH 1/2] [mlir][spirv] Add basic arithmetic folds
Many of the SPIR-V arithmetic operations have no basic constant folds,
which negatively impacts the readability of lowered or otherwise
generated code. This commit works towards adding folds for more operations,
namely: [SU]Div, [SU]Mod, SRem, SNegate, IAddCarry, [SU]MulExtended.
Resolves #70704
TODO:
- test cases for vector operands are missing (simply not written yet)
- IAddCarry and [SU]MulExtended are still missing, as it is unclear how to
construct the Composite Attribute
---
.../Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 11 +
.../SPIRV/IR/SPIRVCanonicalization.cpp | 182 +++++++++++++
.../SPIRV/Transforms/canonicalize.mlir | 254 ++++++++++++++++++
3 files changed, 447 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..a4d342addb86f81 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -534,6 +534,8 @@ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -573,6 +575,8 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -634,6 +638,8 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
%3 = spirv.SNegate %2 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -673,6 +679,8 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -707,6 +715,8 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -811,6 +821,7 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
```
}];
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..f596b5d1cfcbfbe 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -69,6 +69,14 @@ static Attribute extractCompositeElement(Attribute composite,
return {};
}
+static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
+ bool div0 = b.isZero();
+
+ bool overflow = a.isMinSignedValue() && b.isAllOnes();
+
+ return div0 || overflow;
+}
+
//===----------------------------------------------------------------------===//
// TableGen'erated canonicalizers
//===----------------------------------------------------------------------===//
@@ -290,6 +298,180 @@ OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
[](APInt a, const APInt &b) { return std::move(a) - b; });
}
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
+  // Dividing by one is the identity: sdiv (x, 1) = x.
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // Per the SPIR-V spec, signed division is undefined when Operand 2 is 0,
+  // and when Operand 2 is -1 while Operand 1 is the minimum representable
+  // value of the operands' type (signed overflow). Such inputs must not be
+  // folded, so the flag below records whether any division was undefined.
+  bool sawUndefined = false;
+  Attribute folded = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &lhs, const APInt &rhs) {
+        // The result is discarded once an input is undefined, so return a
+        // placeholder rather than calling sdiv on invalid operands.
+        sawUndefined = sawUndefined || isDivZeroOrOverflow(lhs, rhs);
+        if (sawUndefined)
+          return lhs;
+        return lhs.sdiv(rhs);
+      });
+
+  if (sawUndefined)
+    return Attribute();
+  return folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
+  // smod (x, 1) = 0. getZeroAttr is used (rather than getIntegerAttr) so the
+  // fold also works when the result type is a vector of integers.
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to SPIR-V spec:
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
+  //
+  // So don't fold during undefined behaviour.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a; // Placeholder; the result is dropped below.
+        }
+        // c = |a| mod |b|, then moved into the range implied by b's sign.
+        APInt c = a.abs().urem(b.abs());
+        if (c.isZero())
+          return c;
+        APInt zero = APInt::getZero(c.getBitWidth());
+        if (b.isNegative())
+          return a.isNegative() ? (zero - c) : (b + c);
+        return a.isNegative() ? (b - c) : c;
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
+  // Two stacked negations cancel out: -(-x) = 0 - (0 - x) = x.
+  if (auto innerNegate = getOperand().getDefiningOp<spirv::SNegateOp>())
+    return innerNegate->getOperand(0);
+
+  // The SPIR-V spec defines this op as the signed-integer subtraction of the
+  // operand from zero, which is what a constant operand is folded to.
+  //
+  // TODO: `return 0 - value;` also appears to work, and since SPIR-V integers
+  // are only 16/32/64 bits wide that should be okay, but it needs to be
+  // double-checked.
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &value) {
+        APInt zero = APInt::getZero(value.getBitWidth());
+        return zero - value;
+      });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
+  // x % 1 = 0. getZeroAttr is used (rather than getIntegerAttr) so the fold
+  // also works when the result type is a vector of integers.
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to SPIR-V spec:
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 1.
+  //
+  // So don't fold if it would do undefined behaviour.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a; // Placeholder; the result is dropped below.
+        }
+        return a.srem(b);
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
+  // Dividing by one is the identity: udiv (x, 1) = x.
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // Per the SPIR-V spec, this is the unsigned-integer division of Operand 1
+  // by Operand 2, and it is undefined when Operand 2 is 0. Such inputs must
+  // not be folded, so the flag below records whether any evaluated division
+  // had a zero divisor.
+  bool sawZeroDivisor = false;
+  Attribute folded = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &lhs, const APInt &rhs) {
+        // The result is discarded once a divisor is zero, so return a
+        // placeholder rather than calling udiv on invalid operands.
+        sawZeroDivisor = sawZeroDivisor || rhs.isZero();
+        if (sawZeroDivisor)
+          return lhs;
+        return lhs.udiv(rhs);
+      });
+
+  return sawZeroDivisor ? Attribute() : folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
+  // umod (x, 1) = 0: every value is a multiple of one, so the remainder is
+  // zero (not x). getZeroAttr also covers vector-of-integer result types.
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
+  // undefined if Operand 2 is 0.
+  //
+  // So don't fold during undefined behaviour.
+  bool div0 = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0 || b.isZero()) {
+          div0 = true;
+          return a; // Placeholder; the result is dropped below.
+        }
+        return a.urem(b);
+      });
+  return div0 ? Attribute() : res;
+}
+
//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a..93eceb036a93a23 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -462,10 +462,264 @@ func.func @const_fold_vector_isub() -> vector<3xi32> {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sdiv_x_1
+func.func @sdiv_x_1(%arg0 : i32) -> i32 {
+ // CHECK-NEXT: return %arg0 : i32
+ %c1 = spirv.Constant 1 : i32
+ %2 = spirv.SDiv %arg0, %c1: i32
+ return %2 : i32
+}
+
+// CHECK-LABEL: @sdiv_div_0_or_overflow
+func.func @sdiv_div_0_or_overflow() -> (i32, i32) {
+ // CHECK: spirv.SDiv
+ // CHECK: spirv.SDiv
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %min_i32 = spirv.Constant -2147483648 : i32
+
+ %0 = spirv.SDiv %cn1, %c0 : i32
+ %1 = spirv.SDiv %min_i32, %cn1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_sdiv
+func.func @const_fold_scalar_sdiv() -> (i32, i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: spirv.Constant -18
+ // CHECK-DAG: spirv.Constant -2
+ // CHECK-DAG: spirv.Constant -7
+ // CHECK-DAG: spirv.Constant 8
+ %0 = spirv.SDiv %c56, %c7 : i32
+ %1 = spirv.SDiv %c56, %cn8 : i32
+ %2 = spirv.SDiv %cn8, %c3 : i32
+ %3 = spirv.SDiv %c56, %cn3 : i32
+ return %0, %1, %2, %3: i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smod_x_1
+func.func @smod_x_1(%arg0 : i32) -> i32 {
+ // CHECK: spirv.Constant 0
+ %c1 = spirv.Constant 1 : i32
+ %0 = spirv.SMod %arg0, %c1: i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @smod_div_0_or_overflow
+func.func @smod_div_0_or_overflow() -> (i32, i32) {
+ // CHECK: spirv.SMod
+ // CHECK: spirv.SMod
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %min_i32 = spirv.Constant -2147483648 : i32
+
+ %0 = spirv.SMod %cn1, %c0 : i32
+ %1 = spirv.SMod %min_i32, %cn1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_smod
+func.func @const_fold_scalar_smod() -> (i32, i32, i32, i32, i32, i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %cn56 = spirv.Constant -56 : i32
+ %c59 = spirv.Constant 59 : i32
+ %cn59 = spirv.Constant -59 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+ // CHECK-DAG: %[[FIFTYTHREE:.*]] = spirv.Constant 53 : i32
+ // CHECK-DAG: %[[NFIFTYTHREE:.*]] = spirv.Constant -53 : i32
+ // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+ // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+ %0 = spirv.SMod %c56, %c7 : i32
+ %1 = spirv.SMod %c56, %cn8 : i32
+ %2 = spirv.SMod %c56, %c3 : i32
+ %3 = spirv.SMod %cn3, %c56 : i32
+ %4 = spirv.SMod %cn3, %cn56 : i32
+ %5 = spirv.SMod %c59, %c56 : i32
+ %6 = spirv.SMod %c59, %cn56 : i32
+ %7 = spirv.SMod %cn59, %cn56 : i32
+
+ // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[FIFTYTHREE]], %[[NTHREE]], %[[THREE]], %[[NFIFTYTHREE]], %[[NTHREE]]
+ return %0, %1, %2, %3, %4, %5, %6, %7 : i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @snegate_twice
+func.func @snegate_twice(%arg0 : i32) -> i32 {
+ // CHECK-NEXT: return %arg0 : i32
+ %0 = spirv.SNegate %arg0 : i32
+ %1 = spirv.SNegate %0 : i32
+ return %1 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_snegate
+func.func @const_fold_scalar_snegate() -> (i32, i32, i32) {
+ %c0 = spirv.Constant 0 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+ // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+ // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ %0 = spirv.SNegate %c0 : i32
+ %1 = spirv.SNegate %c3 : i32
+ %2 = spirv.SNegate %cn3 : i32
+
+ // CHECK: return %[[ZERO]], %[[NTHREE]], %[[THREE]]
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @srem_x_1
+func.func @srem_x_1(%arg0 : i32) -> i32 {
+ // CHECK: spirv.Constant 0
+ %c1 = spirv.Constant 1 : i32
+ %0 = spirv.SRem %arg0, %c1 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @srem_div_0_or_overflow
+func.func @srem_div_0_or_overflow() -> (i32, i32) {
+ // CHECK: spirv.SRem
+ // CHECK: spirv.SRem
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %min_i32 = spirv.Constant -2147483648 : i32
+
+ %0 = spirv.SRem %cn1, %c0 : i32
+ %1 = spirv.SRem %min_i32, %cn1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_srem
+func.func @const_fold_scalar_srem() -> (i32, i32, i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[ONE:.*]] = spirv.Constant 1 : i32
+ // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+ // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+ // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ %0 = spirv.SRem %c56, %c7 : i32
+ %1 = spirv.SRem %c56, %cn8 : i32
+ %2 = spirv.SRem %c56, %c3 : i32
+ %3 = spirv.SRem %cn3, %c56 : i32
+ %4 = spirv.SRem %c7, %cn3 : i32
+ // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[NTHREE]], %[[ONE]]
+ return %0, %1, %2, %3, %4 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @udiv_x_1
+func.func @udiv_x_1(%arg0 : i32) -> i32 {
+ // CHECK-NEXT: return %arg0 : i32
+ %c1 = spirv.Constant 1 : i32
+ %2 = spirv.UDiv %arg0, %c1: i32
+ return %2 : i32
+}
+
+// CHECK-LABEL: @udiv_div_0
+func.func @udiv_div_0() -> i32 {
+ // CHECK: spirv.UDiv
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %0 = spirv.UDiv %cn1, %c0 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_udiv
+func.func @const_fold_scalar_udiv() -> (i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant 1431655762
+ // CHECK-DAG: spirv.Constant 8
+ %0 = spirv.UDiv %c56, %c7 : i32
+ %1 = spirv.UDiv %cn8, %c3 : i32
+ %2 = spirv.UDiv %c56, %cn8 : i32
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @umod_x_1
+func.func @umod_x_1(%arg0 : i32) -> i32 {
+  // CHECK: spirv.Constant 0
+  %c1 = spirv.Constant 1 : i32
+  %2 = spirv.UMod %arg0, %c1: i32 // x umod 1 is always 0, never x.
+  return %2 : i32
+}
+
+// CHECK-LABEL: @umod_div_0
+func.func @umod_div_0() -> i32 {
+ // CHECK: spirv.UMod
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %0 = spirv.UMod %cn1, %c0 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_umod
+func.func @const_fold_scalar_umod() -> (i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant 2
+ // CHECK-DAG: spirv.Constant 56
+ %0 = spirv.UMod %c56, %c7 : i32
+ %1 = spirv.UMod %cn8, %c3 : i32
+ %2 = spirv.UMod %c56, %cn8 : i32
+ return %0, %1, %2 : i32, i32, i32
+}
+
// CHECK-LABEL: @umod_fold
// CHECK-SAME: (%[[ARG:.*]]: i32)
func.func @umod_fold(%arg0: i32) -> (i32, i32) {
>From d5f5891b8a39e636e76b6427a2bb1aa5b4ad255c Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Mon, 6 Nov 2023 19:11:27 +0100
Subject: [PATCH 2/2] add IAddCarry as a canonicalization
- struct constants are not yet implemented, so the result is built with
CompositeConstruct instead; since that requires creating a new
operation, this is a canonicalization pattern and not a fold
---
.../Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 2 +
.../SPIRV/IR/SPIRVCanonicalization.cpp | 73 +++++++++++++++++++
.../SPIRV/Transforms/canonicalize.mlir | 47 ++++++++++++
3 files changed, 122 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index a4d342addb86f81..bf14417a5966f39 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
%2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index f596b5d1cfcbfbe..6dfeea3f434eac6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -123,6 +123,79 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
results.add<CombineChainedAccessChain>(context);
}
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// Folds spirv.IAddCarry. Struct constants are not yet implemented, so the
+// result is assembled with CompositeConstructOp, which a fold cannot create.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+    Type constituentType = operands[0].getType();
+
+    // According to the SPIR-V spec:
+    //
+    //  Result Type must be from OpTypeStruct. The struct must have two
+    //  members...
+    //
+    //  Member 0 of the result gets the low-order bits (full component width)
+    //  of the addition.
+    //
+    //  Member 1 of the result gets the high-order (carry) bit of the result
+    //  of the addition. That is, it gets the value 1 if the addition
+    //  overflowed the component width, and 0 otherwise.
+    //
+    // Hence every struct built below is laid out as <sum, carry>.
+
+    // iaddcarry (x, 0) = <x, 0>: the sum is x and nothing is carried, so the
+    // zero operand itself serves as the carry member.
+    if (matchPattern(operands[1], m_Zero())) {
+      SmallVector<Value> constituents = {operands[0], operands[1]};
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto adds = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+    // The addition overflowed exactly when the wrapped sum is (unsigned)
+    // smaller than one of the operands.
+    auto carrys = constFoldBinaryOp<IntegerAttr>(
+        ArrayRef{lhs, rhs}, [](const APInt &a, const APInt &b) {
+          bool overflowed = (a + b).ult(a);
+          return APInt(a.getBitWidth(), overflowed ? 1 : 0);
+        });
+    if (!adds || !carrys)
+      return failure();
+
+    // Materialize both members as constants, then assemble <sum, carry>.
+    Value carrysVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+    Value addsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+    SmallVector<Value> constituents = {addsVal, carrysVal};
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<IAddCarryFold>(context); // Registers the IAddCarryFold pattern.
+}
+
//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 93eceb036a93a23..265bfade23dd863 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -336,6 +336,53 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iaddcarry_x_0
+func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ %c0 = spirv.Constant 0 : i32
+
+ // CHECK: spirv.CompositeConstruct
+ %0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iaddcarry
+func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant -3
+ // CHECK-DAG: spirv.CompositeConstruct
+ // CHECK-DAG: spirv.Constant 1
+ // CHECK-DAG: spirv.Constant -13
+ // CHECK-DAG: spirv.CompositeConstruct
+ %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_iaddcarry
+func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  // Per lane: 5 + -8 does not carry, while -5 + -8 and -1 + 1 wrap around.
+  %v0 = spirv.Constant dense<[5, -5, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
+
+  // Carries and sums become two constants feeding the struct construction.
+  // CHECK: spirv.Constant dense<[0, 1, 1]>
+  // CHECK-NEXT: spirv.Constant dense<[-3, -13, 0]>
+  // CHECK-NEXT: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.IMul
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list