[Mlir-commits] [mlir] [mlir][spirv] Add basic arithmetic folds (PR #71414)

Finn Plummer llvmlistbot at llvm.org
Mon Nov 6 10:11:55 PST 2023


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/71414

>From 2ec719195706ec742ff31e870d0086528aaef4a5 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Mon, 6 Nov 2023 11:06:10 +0100
Subject: [PATCH 1/2] [mlir][spirv] Add basic arithmetic folds

Many of the arithmetic operations lack basic constant folds, which
negatively impacts the readability of lowered or otherwise generated
code. This commit works towards adding folds for more operations,
namely: [SU]Div, [SU]Mod, SRem, SNegate, IAddCarry, [SU]MulExtended

Resolves #70704

TODO:
- test cases for vectors have not been added yet (just lazy)
- IAddCarry and [SU]MulExtended folds are missing, as it is unclear how
  to construct the composite attribute
---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    |  11 +
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 182 +++++++++++++
 .../SPIRV/Transforms/canonicalize.mlir        | 254 ++++++++++++++++++
 3 files changed, 447 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..a4d342addb86f81 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -534,6 +534,8 @@ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -573,6 +575,8 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -634,6 +638,8 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
     %3 = spirv.SNegate %2 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -673,6 +679,8 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -707,6 +715,8 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -811,6 +821,7 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
     ```
   }];
 
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..f596b5d1cfcbfbe 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -69,6 +69,14 @@ static Attribute extractCompositeElement(Attribute composite,
   return {};
 }
 
+/// Returns true if signed `a / b` is undefined in SPIR-V: `b` is zero, or
+/// `a` is the minimum signed value and `b` is -1 (signed overflow).
+static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
+  if (b.isZero())
+    return true;
+  return a.isMinSignedValue() && b.isAllOnes();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'erated canonicalizers
 //===----------------------------------------------------------------------===//
@@ -290,6 +298,180 @@ OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
       [](APInt a, const APInt &b) { return std::move(a) - b; });
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
+  // Dividing by one is the identity: sdiv (x, 1) = x
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // The SPIR-V spec leaves signed division undefined in two cases:
+  //   * Operand 2 is 0;
+  //   * Operand 2 is -1 and Operand 1 is the minimum representable value
+  //     for the operands' type, causing signed overflow.
+  // No value may be invented for those, so the callback only records that
+  // such a pair was seen and the fold is abandoned afterwards. The flag is
+  // sticky: once set, later callback invocations skip the division and
+  // just hand back their left-hand side as a placeholder.
+  bool sawUndefined = false;
+  auto divide = [&](const APInt &lhs, const APInt &rhs) -> APInt {
+    sawUndefined = sawUndefined || isDivZeroOrOverflow(lhs, rhs);
+    return sawUndefined ? lhs : lhs.sdiv(rhs);
+  };
+  Attribute folded =
+      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), divide);
+  if (sawUndefined)
+    return Attribute();
+  return folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
+  // smod (x, 1) = 0. getZeroAttr is used (rather than getIntegerAttr) so
+  // that vector result types yield a zero splat as well.
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to SPIR-V spec:
+  //
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
+  //
+  // So don't fold during undefined behaviour
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a;
+        }
+        APInt c = a.abs().urem(b.abs());
+        if (c.isZero())
+          return c;
+        if (b.isNegative()) {
+          return a.isNegative() ? (-c) : (b + c);
+        }
+        return a.isNegative() ? (b - c) : c;
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
+  // Double negation cancels out: -(-x) = 0 - (0 - x) = x
+  Value input = getOperand();
+  if (auto inner = input.getDefiningOp<spirv::SNegateOp>())
+    return inner->getOperand(0);
+
+  // Per the SPIR-V spec, this op is the signed-integer subtract of Operand
+  // from zero.
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &value) {
+        // The zero is built with the operand's bit width so the subtraction
+        // happens between two APInts of identical width.
+        // TODO: `return 0 - value;` also appears to work; double-check.
+        const APInt zero = APInt::getZero(value.getBitWidth());
+        return zero - value;
+      });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
+  // x % 1 = 0. getZeroAttr also yields a zero splat for vector result types.
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to SPIR-V spec:
+  //
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 1.
+  //
+  // Don't fold undefined behaviour: the flag is sticky and discards the result.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](APInt a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a;
+        }
+        return a.srem(b);
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
+  // Dividing by one is the identity: udiv (x, 1) = x
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // The SPIR-V spec defines this op as the unsigned-integer division of
+  // Operand 1 by Operand 2 and leaves it undefined if Operand 2 is 0.
+  // No value may be invented for that case, so the callback only records
+  // that a zero divisor was seen and the fold is abandoned afterwards.
+  // The flag is sticky: once set, later callback invocations skip the
+  // division and just hand back their left-hand side as a placeholder.
+  bool sawZeroDivisor = false;
+  auto divide = [&](const APInt &lhs, const APInt &rhs) -> APInt {
+    sawZeroDivisor = sawZeroDivisor || rhs.isZero();
+    return sawZeroDivisor ? lhs : lhs.udiv(rhs);
+  };
+  Attribute folded =
+      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), divide);
+  if (sawZeroDivisor)
+    return Attribute();
+  return folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
+  // umod (x, 1) = 0: every value is a multiple of one, so fold to zero, not x.
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  //
+  // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
+  // undefined if Operand 2 is 0.
+  //
+  // So don't fold during undefined behaviour.
+  bool div0 = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0 || b.isZero()) {
+          div0 = true;
+          return a;
+        }
+        return a.urem(b);
+      });
+  return div0 ? Attribute() : res;
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a..93eceb036a93a23 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -462,10 +462,264 @@ func.func @const_fold_vector_isub() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sdiv_x_1
+func.func @sdiv_x_1(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: return %arg0 : i32
+  %c1 = spirv.Constant 1  : i32
+  %2 = spirv.SDiv %arg0, %c1: i32
+  return %2 : i32
+}
+
+// CHECK-LABEL: @sdiv_div_0_or_overflow
+func.func @sdiv_div_0_or_overflow() -> (i32, i32) {
+  // CHECK: spirv.SDiv
+  // CHECK: spirv.SDiv
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+
+  %0 = spirv.SDiv %cn1, %c0 : i32
+  %1 = spirv.SDiv %min_i32, %cn1 : i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_sdiv
+func.func @const_fold_scalar_sdiv() -> (i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: spirv.Constant -18
+  // CHECK-DAG: spirv.Constant -2
+  // CHECK-DAG: spirv.Constant -7
+  // CHECK-DAG: spirv.Constant 8
+  %0 = spirv.SDiv %c56, %c7 : i32
+  %1 = spirv.SDiv %c56, %cn8 : i32
+  %2 = spirv.SDiv %cn8, %c3 : i32
+  %3 = spirv.SDiv %c56, %cn3 : i32
+  return %0, %1, %2, %3: i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smod_x_1
+func.func @smod_x_1(%arg0 : i32) -> i32 {
+  // CHECK: spirv.Constant 0
+  %c1 = spirv.Constant 1 : i32
+  %0 = spirv.SMod %arg0, %c1: i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @smod_div_0_or_overflow
+func.func @smod_div_0_or_overflow() -> (i32, i32) {
+  // CHECK: spirv.SMod
+  // CHECK: spirv.SMod
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+
+  %0 = spirv.SMod %cn1, %c0 : i32
+  %1 = spirv.SMod %min_i32, %cn1 : i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_smod
+func.func @const_fold_scalar_smod() -> (i32, i32, i32, i32, i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %cn56 = spirv.Constant -56 : i32
+  %c59 = spirv.Constant 59 : i32
+  %cn59 = spirv.Constant -59 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+  // CHECK-DAG: %[[FIFTYTHREE:.*]] = spirv.Constant 53 : i32
+  // CHECK-DAG: %[[NFIFTYTHREE:.*]] = spirv.Constant -53 : i32
+  // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  %0 = spirv.SMod %c56, %c7 : i32
+  %1 = spirv.SMod %c56, %cn8 : i32
+  %2 = spirv.SMod %c56, %c3 : i32
+  %3 = spirv.SMod %cn3, %c56 : i32
+  %4 = spirv.SMod %cn3, %cn56 : i32
+  %5 = spirv.SMod %c59, %c56 : i32
+  %6 = spirv.SMod %c59, %cn56 : i32
+  %7 = spirv.SMod %cn59, %cn56 : i32
+
+  // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[FIFTYTHREE]], %[[NTHREE]], %[[THREE]], %[[NFIFTYTHREE]], %[[NTHREE]]
+  return %0, %1, %2, %3, %4, %5, %6, %7 : i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @snegate_twice
+func.func @snegate_twice(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: return %arg0 : i32
+  %0 = spirv.SNegate %arg0 : i32
+  %1 = spirv.SNegate %0 : i32
+  return %1 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_snegate
+func.func @const_fold_scalar_snegate() -> (i32, i32, i32) {
+  %c0 = spirv.Constant 0 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  %0 = spirv.SNegate %c0 : i32
+  %1 = spirv.SNegate %c3 : i32
+  %2 = spirv.SNegate %cn3 : i32
+
+  // CHECK: return %[[ZERO]], %[[NTHREE]], %[[THREE]]
+  return %0, %1, %2  : i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @srem_x_1
+func.func @srem_x_1(%arg0 : i32) -> i32 {
+  // CHECK: spirv.Constant 0
+  %c1 = spirv.Constant 1 : i32
+  %0 = spirv.SRem %arg0, %c1 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @srem_div_0_or_overflow
+func.func @srem_div_0_or_overflow() -> (i32, i32) {
+  // CHECK: spirv.SRem
+  // CHECK: spirv.SRem
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+
+  %0 = spirv.SRem %cn1, %c0 : i32
+  %1 = spirv.SRem %min_i32, %cn1 : i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_srem
+func.func @const_fold_scalar_srem() -> (i32, i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[ONE:.*]] = spirv.Constant 1 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  %0 = spirv.SRem %c56, %c7 : i32
+  %1 = spirv.SRem %c56, %cn8 : i32
+  %2 = spirv.SRem %c56, %c3 : i32
+  %3 = spirv.SRem %cn3, %c56 : i32
+  %4 = spirv.SRem %c7, %cn3 : i32
+  // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[NTHREE]], %[[ONE]]
+  return %0, %1, %2, %3, %4 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @udiv_x_1
+func.func @udiv_x_1(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: return %arg0 : i32
+  %c1 = spirv.Constant 1  : i32
+  %2 = spirv.UDiv %arg0, %c1: i32
+  return %2 : i32
+}
+
+// CHECK-LABEL: @udiv_div_0
+func.func @udiv_div_0() -> i32 {
+  // CHECK: spirv.UDiv
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %0 = spirv.UDiv %cn1, %c0 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_udiv
+func.func @const_fold_scalar_udiv() -> (i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant 1431655762
+  // CHECK-DAG: spirv.Constant 8
+  %0 = spirv.UDiv %c56, %c7 : i32
+  %1 = spirv.UDiv %cn8, %c3 : i32
+  %2 = spirv.UDiv %c56, %cn8 : i32
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.UMod
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @umod_x_1
+func.func @umod_x_1(%arg0 : i32) -> i32 {
+  // CHECK: spirv.Constant 0
+  %c1 = spirv.Constant 1  : i32
+  %2 = spirv.UMod %arg0, %c1: i32 // x umod 1 is 0 for every x
+  return %2 : i32
+}
+
+// CHECK-LABEL: @umod_div_0
+func.func @umod_div_0() -> i32 {
+  // CHECK: spirv.UMod
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %0 = spirv.UMod %cn1, %c0 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_umod
+func.func @const_fold_scalar_umod() -> (i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant 2
+  // CHECK-DAG: spirv.Constant 56
+  %0 = spirv.UMod %c56, %c7 : i32
+  %1 = spirv.UMod %cn8, %c3 : i32
+  %2 = spirv.UMod %c56, %cn8 : i32
+  return %0, %1, %2 : i32, i32, i32
+}
+
 // CHECK-LABEL: @umod_fold
 // CHECK-SAME: (%[[ARG:.*]]: i32)
 func.func @umod_fold(%arg0: i32) -> (i32, i32) {

>From ce0575755c388662067c3be469a0208fb8eeb63c Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Mon, 6 Nov 2023 19:11:27 +0100
Subject: [PATCH 2/2] initial attempt for IAddCarry

---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    |  2 ++
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 34 +++++++++++++++++++
 .../SPIRV/Transforms/canonicalize.mlir        | 25 ++++++++++++++
 3 files changed, 61 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index a4d342addb86f81..63ebb064cd846b5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
     %2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index f596b5d1cfcbfbe..16b90b80c56ed48 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -257,6 +257,40 @@ OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
       [](APInt a, const APInt &b) { return std::move(a) + b; });
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::IAddCarryOp::fold(FoldAdaptor adaptor) {
+  // According to the SPIR-V spec:
+  //
+  //  Result Type must be from OpTypeStruct.  The struct must have two
+  //  members...
+  //
+  //  Member 0 of the result gets the low-order bits (full component width) of
+  //  the addition.
+  //
+  //  Member 1 of the result gets the high-order (carry) bit of the result of
+  //  the addition. That is, it gets the value 1 if the addition overflowed
+  //  the component width, and 0 otherwise.
+  auto operands = adaptor.getOperands();
+  auto adds = constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a + b; });
+  if (!adds)
+    return {}; // The sum could not be constant folded: nothing to do.
+  // The addition wrapped iff the truncated sum is below an operand.
+  auto carrys = constFoldBinaryOp<IntegerAttr>(
+      ArrayRef{adds, operands[0]}, [](const APInt &sum, const APInt &lhs) {
+        APInt zero = APInt::getZero(sum.getBitWidth());
+        return sum.ult(lhs) ? (zero + 1) : zero;
+      });
+  if (!carrys)
+    return {};
+  // FIXME: How can we return this as a spirv.Constant attr? Maybe
+  // DenseElementAttr or ArrayAttr? Neither seem to do the trick :(
+  return ArrayAttr::get(getContext(), {adds, carrys});
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.IMul
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 93eceb036a93a23..bc7c804384562da 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -336,6 +336,31 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @const_fold_scalar_iaddcarry
+func.func @const_fold_scalar_iaddcarry() -> !spirv.struct<(i32, i32)> {
+  %c5 = spirv.Constant 5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_iaddcarry
+func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %c5 = spirv.Constant dense<5> : vector<3xi32>
+  %cn8 = spirv.Constant dense<-8> : vector<3xi32>
+
+  %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IMul
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list