[Mlir-commits] [mlir] [mlir][spirv] Add basic arithmetic folds (PR #71414)

Finn Plummer llvmlistbot at llvm.org
Fri Nov 10 07:39:33 PST 2023


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/71414

From 2ec719195706ec742ff31e870d0086528aaef4a5 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Mon, 6 Nov 2023 11:06:10 +0100
Subject: [PATCH 1/2] [mlir][spirv] Add basic arithmetic folds

We do not have basic constant folds for many of the arithmetic operations,
which negatively impacts the readability of lowered or otherwise generated
code. This commit works towards adding more folds, namely for: [SU]Div,
[SU]Mod, SRem, SNegate, IAddCarry and [SU]MulExtended.

Resolves #70704

TODO:
- vector test cases are not yet added
- IAddCarry and [SU]MulExtended folds are missing, as it is unclear how to
  construct the composite result attribute
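
For reference, all of the fold hooks below follow MLIR's standard contract:
return the folded result as an OpFoldResult (either an existing SSA value or
a constant attribute), or a null result to leave the op untouched. A minimal
sketch of that shape, using a hypothetical MyOp (illustrative only, not part
of this patch):

    // Sketch: the general shape of the binary fold hooks added here.
    // constFoldBinaryOp returns a null attribute when the operands are not
    // constant, which declines the fold.
    OpFoldResult MyOp::fold(FoldAdaptor adaptor) {
      // Algebraic identity that needs no constant operands: x op 1 = x.
      if (matchPattern(getOperand2(), m_One()))
        return getOperand1();
      // Elementwise constant folding over scalars and dense vector attributes.
      return constFoldBinaryOp<IntegerAttr>(
          adaptor.getOperands(),
          [](const APInt &a, const APInt &b) { return a + b; });
    }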
---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    |  11 +
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 182 +++++++++++++
 .../SPIRV/Transforms/canonicalize.mlir        | 254 ++++++++++++++++++
 3 files changed, 447 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..a4d342addb86f81 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -534,6 +534,8 @@ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -573,6 +575,8 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -634,6 +638,8 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
     %3 = spirv.SNegate %2 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -673,6 +679,8 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -707,6 +715,8 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -811,6 +821,7 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
     ```
   }];
 
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..f596b5d1cfcbfbe 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -69,6 +69,14 @@ static Attribute extractCompositeElement(Attribute composite,
   return {};
 }
 
+static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
+  bool div0 = b.isZero();
+
+  bool overflow = a.isMinSignedValue() && b.isAllOnes();
+
+  return div0 || overflow;
+}
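
This helper encodes the two cases the SPIR-V spec leaves undefined for signed
division and remainder: a zero divisor, and the minimum signed value divided
by -1, whose mathematical result 2^(N-1) is not representable in N bits. A
small standalone check of the predicate (illustrative only):

    // Both undefined-behavior cases trip the guard.
    #include "llvm/ADT/APInt.h"
    using llvm::APInt;

    bool demoGuard() {
      APInt seven(32, 7);
      APInt zero(32, 0);
      APInt minI32 = APInt::getSignedMinValue(32); // -2147483648
      APInt negOne = APInt::getAllOnes(32);        // -1 in two's complement
      return isDivZeroOrOverflow(seven, zero) &&   // division by zero
             isDivZeroOrOverflow(minI32, negOne);  // INT_MIN / -1 overflow
    }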
+
 //===----------------------------------------------------------------------===//
 // TableGen'erated canonicalizers
 //===----------------------------------------------------------------------===//
@@ -290,6 +298,180 @@ OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
       [](APInt a, const APInt &b) { return std::move(a) - b; });
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
+  // sdiv (x, 1) = x
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // According to the SPIR-V spec:
+  //
+  // Signed-integer division of Operand 1 divided by Operand 2.
+  // Results are computed per component. Behavior is undefined if Operand 2 is
+  // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
+  // representable value for the operands' type, causing signed overflow.
+  //
+  // So don't fold when the behavior is undefined.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a;
+        }
+        return a.sdiv(b);
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
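
Note how the lambda handles vector operands: constFoldBinaryOp applies it
lane by lane over dense attributes, so a single undefined lane must poison
the entire result. The sticky div0OrOverflow flag does exactly that; the
lane value returned once the flag is set is a dummy, because the whole
result is discarded at the end. The same pattern over plain integers
(illustrative only):

    // Sticky-flag folding: one bad lane aborts the whole fold.
    bool anyUB = false;
    auto foldLane = [&](int64_t a, int64_t b) -> int64_t {
      if (anyUB || b == 0) { // once set, stays set for all remaining lanes
        anyUB = true;
        return a;            // dummy; the caller discards the result
      }
      return a / b;
    };
    // After mapping foldLane over every lane:
    //   result = anyUB ? Attribute() : res;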
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
+  // smod (x, 1) = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getIntegerAttr(getType(), 0);
+
+  // According to SPIR-V spec:
+  //
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
+  //
+  // So don't fold when the behavior is undefined.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a;
+        }
+        APInt c = a.abs().urem(b.abs());
+        if (c.isZero())
+          return c;
+        if (b.isNegative()) {
+          APInt zero = APInt::getZero(c.getBitWidth()); // See TODO in SNegate.
+          return a.isNegative() ? (zero - c) : (b + c);
+        }
+        return a.isNegative() ? (b - c) : c;
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
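
The lambda computes the magnitude with abs/urem and then fixes up the sign so
that it follows Operand 2, which is what distinguishes SMod from SRem. The
same computation on plain 64-bit integers, for concreteness (illustrative
only; C++'s % truncates toward zero, i.e. it behaves like SRem):

    // smod: remainder whose sign follows the divisor b.
    int64_t smod(int64_t a, int64_t b) {
      int64_t r = a % b;                                 // SRem-style remainder
      return (r != 0 && (r < 0) != (b < 0)) ? r + b : r; // flip toward b's sign
    }
    // smod(-3, 56) == 53, smod(59, -56) == -53, smod(56, -8) == 0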
+
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
+  // -(-x) = 0 - (0 - x) = x
+  auto op = getOperand();
+  if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
+    return negateOp->getOperand(0);
+
+  // According to the SPIR-V spec:
+  //
+  // Signed-integer subtract of Operand from zero.
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a) {
+        // TODO: `return 0 - a;` also appears to work; since SPIR-V integers are
+        // only 16/32/64 bits wide it should be equivalent, but double-check.
+        APInt zero = APInt::getZero(a.getBitWidth());
+        return zero - a;
+      });
+}
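
Unlike the division folds, no undefined-behavior guard is needed here:
two's-complement negation is total, and negating the minimum signed value
simply wraps back to itself, which APInt's modular arithmetic reproduces.
A one-line check (illustrative only):

    // -INT_MIN wraps to INT_MIN under two's complement.
    APInt m = APInt::getSignedMinValue(32);
    assert((APInt::getZero(32) - m) == m);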
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
+  // x % 1 = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getIntegerAttr(getType(), 0);
+
+  // According to SPIR-V spec:
+  //
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 1.
+
+  // Don't fold when the behavior is undefined.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](APInt a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a;
+        }
+        return a.srem(b);
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
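
For contrast with SMod above: SRem's sign follows Operand 1 (the dividend),
which matches C++'s built-in %. Using the smod sketch from earlier
(illustrative only):

    //   a % b (SRem)           smod(a, b) (SMod)
    //   -3 % 56 == -3          smod(-3, 56) == 53
    //    7 % -3 ==  1          smod( 7, -3) == -2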
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
+  // udiv (x, 1) = x
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // According to the SPIR-V spec:
+  //
+  // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
+  // undefined if Operand 2 is 0.
+  //
+  // So don't fold when the behavior is undefined.
+  bool div0 = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0 || b.isZero()) {
+          div0 = true;
+          return a;
+        }
+        return a.udiv(b);
+      });
+  return div0 ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
+  // umod (x, 1) = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getIntegerAttr(getType(), 0);
+
+  // According to the SPIR-V spec:
+  //
+  // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
+  // undefined if Operand 2 is 0.
+  //
+  // So don't fold when the behavior is undefined.
+  bool div0 = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0 || b.isZero()) {
+          div0 = true;
+          return a;
+        }
+        return a.urem(b);
+      });
+  return div0 ? Attribute() : res;
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a..93eceb036a93a23 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -462,10 +462,264 @@ func.func @const_fold_vector_isub() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sdiv_x_1
+func.func @sdiv_x_1(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: return %arg0 : i32
+  %c1 = spirv.Constant 1  : i32
+  %2 = spirv.SDiv %arg0, %c1: i32
+  return %2 : i32
+}
+
+// CHECK-LABEL: @sdiv_div_0_or_overflow
+func.func @sdiv_div_0_or_overflow() -> (i32, i32) {
+  // CHECK: spirv.SDiv
+  // CHECK: spirv.SDiv
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+
+  %0 = spirv.SDiv %cn1, %c0 : i32
+  %1 = spirv.SDiv %min_i32, %cn1 : i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_sdiv
+func.func @const_fold_scalar_sdiv() -> (i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: spirv.Constant -18
+  // CHECK-DAG: spirv.Constant -2
+  // CHECK-DAG: spirv.Constant -7
+  // CHECK-DAG: spirv.Constant 8
+  %0 = spirv.SDiv %c56, %c7 : i32
+  %1 = spirv.SDiv %c56, %cn8 : i32
+  %2 = spirv.SDiv %cn8, %c3 : i32
+  %3 = spirv.SDiv %c56, %cn3 : i32
+  return %0, %1, %2, %3: i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smod_x_1
+func.func @smod_x_1(%arg0 : i32) -> i32 {
+  // CHECK: spirv.Constant 0
+  %c1 = spirv.Constant 1 : i32
+  %0 = spirv.SMod %arg0, %c1: i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @smod_div_0_or_overflow
+func.func @smod_div_0_or_overflow() -> (i32, i32) {
+  // CHECK: spirv.SMod
+  // CHECK: spirv.SMod
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+
+  %0 = spirv.SMod %cn1, %c0 : i32
+  %1 = spirv.SMod %min_i32, %cn1 : i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_smod
+func.func @const_fold_scalar_smod() -> (i32, i32, i32, i32, i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %cn56 = spirv.Constant -56 : i32
+  %c59 = spirv.Constant 59 : i32
+  %cn59 = spirv.Constant -59 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+  // CHECK-DAG: %[[FIFTYTHREE:.*]] = spirv.Constant 53 : i32
+  // CHECK-DAG: %[[NFIFTYTHREE:.*]] = spirv.Constant -53 : i32
+  // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  %0 = spirv.SMod %c56, %c7 : i32
+  %1 = spirv.SMod %c56, %cn8 : i32
+  %2 = spirv.SMod %c56, %c3 : i32
+  %3 = spirv.SMod %cn3, %c56 : i32
+  %4 = spirv.SMod %cn3, %cn56 : i32
+  %5 = spirv.SMod %c59, %c56 : i32
+  %6 = spirv.SMod %c59, %cn56 : i32
+  %7 = spirv.SMod %cn59, %cn56 : i32
+
+  // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[FIFTYTHREE]], %[[NTHREE]], %[[THREE]], %[[NFIFTYTHREE]], %[[NTHREE]]
+  return %0, %1, %2, %3, %4, %5, %6, %7 : i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @snegate_twice
+func.func @snegate_twice(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: return %arg0 : i32
+  %0 = spirv.SNegate %arg0 : i32
+  %1 = spirv.SNegate %0 : i32
+  return %1 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_snegate
+func.func @const_fold_scalar_snegate() -> (i32, i32, i32) {
+  %c0 = spirv.Constant 0 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  %0 = spirv.SNegate %c0 : i32
+  %1 = spirv.SNegate %c3 : i32
+  %2 = spirv.SNegate %cn3 : i32
+
+  // CHECK: return %[[ZERO]], %[[NTHREE]], %[[THREE]]
+  return %0, %1, %2  : i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @srem_x_1
+func.func @srem_x_1(%arg0 : i32) -> i32 {
+  // CHECK: spirv.Constant 0
+  %c1 = spirv.Constant 1 : i32
+  %0 = spirv.SRem %arg0, %c1 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @srem_div_0_or_overflow
+func.func @srem_div_0_or_overflow() -> (i32, i32) {
+  // CHECK: spirv.SRem
+  // CHECK: spirv.SRem
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+
+  %0 = spirv.SRem %cn1, %c0 : i32
+  %1 = spirv.SRem %min_i32, %cn1 : i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_srem
+func.func @const_fold_scalar_srem() -> (i32, i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[ONE:.*]] = spirv.Constant 1 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  %0 = spirv.SRem %c56, %c7 : i32
+  %1 = spirv.SRem %c56, %cn8 : i32
+  %2 = spirv.SRem %c56, %c3 : i32
+  %3 = spirv.SRem %cn3, %c56 : i32
+  %4 = spirv.SRem %c7, %cn3 : i32
+  // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[NTHREE]], %[[ONE]]
+  return %0, %1, %2, %3, %4 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @udiv_x_1
+func.func @udiv_x_1(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: return %arg0 : i32
+  %c1 = spirv.Constant 1  : i32
+  %2 = spirv.UDiv %arg0, %c1: i32
+  return %2 : i32
+}
+
+// CHECK-LABEL: @udiv_div_0
+func.func @udiv_div_0() -> i32 {
+  // CHECK: spirv.UDiv
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %0 = spirv.UDiv %cn1, %c0 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_udiv
+func.func @const_fold_scalar_udiv() -> (i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant 1431655762
+  // CHECK-DAG: spirv.Constant 8
+  %0 = spirv.UDiv %c56, %c7 : i32
+  %1 = spirv.UDiv %cn8, %c3 : i32
+  %2 = spirv.UDiv %c56, %cn8 : i32
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.UMod
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @umod_x_1
+func.func @umod_x_1(%arg0 : i32) -> i32 {
+  // CHECK: spirv.Constant 0
+  %c1 = spirv.Constant 1  : i32
+  %2 = spirv.UMod %arg0, %c1: i32
+  return %2 : i32
+}
+
+// CHECK-LABEL: @umod_div_0
+func.func @umod_div_0() -> i32 {
+  // CHECK: spirv.UMod
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %0 = spirv.UMod %cn1, %c0 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_umod
+func.func @const_fold_scalar_umod() -> (i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant 2
+  // CHECK-DAG: spirv.Constant 56
+  %0 = spirv.UMod %c56, %c7 : i32
+  %1 = spirv.UMod %cn8, %c3 : i32
+  %2 = spirv.UMod %c56, %cn8 : i32
+  return %0, %1, %2 : i32, i32, i32
+}
+
 // CHECK-LABEL: @umod_fold
 // CHECK-SAME: (%[[ARG:.*]]: i32)
 func.func @umod_fold(%arg0: i32) -> (i32, i32) {

From 712976831f12a504b3903ef83e384724d4639039 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Mon, 6 Nov 2023 19:11:27 +0100
Subject: [PATCH 2/2] Add IAddCarry/[SU]MulExtended folds as canonicalization patterns

- we cannot yet create a struct constant, as that is not yet implemented, so
  we build the result with CompositeConstruct instead; this requires creating
  a new operation, which cannot be done in a fold
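
A fold hook may only return existing values or constant attributes; it cannot
create new operations. Since a struct result cannot yet be expressed as a
single constant attribute, the rewrites below must build a
CompositeConstructOp plus its constant operands, which is why they are
RewritePatterns. The skeleton they all share, with hypothetical names
(illustrative only):

    // Skeleton of the struct-producing canonicalizations below; SomeOp, ty,
    // loAttr, and hiAttr are placeholders.
    struct SomeExtendedOpFold final : OpRewritePattern<SomeOp> {
      using OpRewritePattern::OpRewritePattern;
      LogicalResult matchAndRewrite(SomeOp op,
                                    PatternRewriter &rewriter) const override {
        // ... match constant operands, compute loAttr/hiAttr ...
        Value lo = rewriter.create<spirv::ConstantOp>(op.getLoc(), ty, loAttr);
        Value hi = rewriter.create<spirv::ConstantOp>(op.getLoc(), ty, hiAttr);
        rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
            op, op.getType(), ValueRange{lo, hi});
        return success();
      }
    };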
---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    |   6 +
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 187 ++++++++++++++++++
 .../SPIRV/Transforms/canonicalize.mlir        | 148 ++++++++++++++
 3 files changed, 341 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index a4d342addb86f81..a73989c41c04cfb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
     %2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -611,6 +613,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
     %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -752,6 +756,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
     %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index f596b5d1cfcbfbe..9f55f2807e458fb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -123,6 +123,193 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
   results.add<CombineChainedAccessChain>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// Struct constants are not yet implemented, so we must materialize the result
+// with CompositeConstructOp; a fold cannot create new ops, hence a pattern.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // iaddcarry (x, 0) = <0, x>
+    if (matchPattern(operands[1], m_Zero())) {
+      constituents.push_back(operands[1]);
+      constituents.push_back(operands[0]);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    //  Result Type must be from OpTypeStruct.  The struct must have two
+    //  members...
+    //
+    //  Member 0 of the result gets the low-order bits (full component width) of
+    //  the addition.
+    //
+    //  Member 1 of the result gets the high-order (carry) bit of the result of
+    //  the addition. That is, it gets the value 1 if the addition overflowed
+    //  the component width, and 0 otherwise.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto adds = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+    auto carrys = constFoldBinaryOp<IntegerAttr>(
+        ArrayRef{lhs, rhs}, [](const APInt &a, const APInt &b) {
+          APInt zero = APInt::getZero(a.getBitWidth());
+          return (a + b).ult(a) ? (zero + 1) : zero;
+        });
+
+    if (!adds || !carrys)
+      return failure();
+
+    Value carrysVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+    constituents.push_back(carrysVal);
+
+    Value addsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+    constituents.push_back(addsVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<IAddCarryFold>(context);
+}
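
The carry computation relies on a standard wraparound check: unsigned
addition a + b (mod 2^N) overflows exactly when the truncated sum is smaller
than one of the addends. On plain 32-bit integers (illustrative only):

    // Carry-out of a 32-bit unsigned add.
    bool carries(uint32_t a, uint32_t b) {
      return (uint32_t)(a + b) < a; // true iff the sum wrapped
    }
    // carries(5, 0xFFFFFFF8) == false          (matches IAddCarry(5, -8))
    // carries(0xFFFFFFFB, 0xFFFFFFF8) == true  (matches IAddCarry(-5, -8))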
+
+//===----------------------------------------------------------------------===//
+// spirv.[SU]MulExtended
+//===----------------------------------------------------------------------===//
+
+// Struct constants are not yet implemented, so we must materialize the result
+// with CompositeConstructOp; a fold cannot create new ops, hence a pattern.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+  using OpRewritePattern<MulOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // [su]mulextended (x, 0) = <0, 0>
+    if (matchPattern(operands[1], m_Zero())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      constituents.push_back(zero);
+      constituents.push_back(zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    // Result Type must be from OpTypeStruct.  The struct must have two
+    // members...
+    //
+    // Member 0 of the result gets the low-order bits of the multiplication.
+    //
+    // Member 1 of the result gets the high-order bits of the multiplication.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto lowBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
+
+    if (!lowBits)
+      return failure();
+
+    auto highBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) {
+          unsigned bitWidth = a.getBitWidth();
+          APInt c;
+          if (IsSigned) {
+            c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+          } else {
+            c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+          }
+          return c.extractBits(bitWidth, bitWidth); // Extract high result
+        });
+
+    if (!highBits)
+      return failure();
+
+    Value lowBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+    constituents.push_back(lowBitsVal);
+
+    Value highBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+    constituents.push_back(highBitsVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
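
The high half of an N-bit multiply is recovered by widening both operands to
2N bits, multiplying exactly, and taking the top N bits; sign-extension
versus zero-extension is the only difference between the signed and unsigned
flavors. On plain 32-bit integers (illustrative only):

    // High 32 bits of a 32x32 multiply, unsigned and signed flavors.
    uint32_t umulhi(uint32_t a, uint32_t b) {
      return (uint32_t)(((uint64_t)a * b) >> 32);
    }
    int32_t smulhi(int32_t a, int32_t b) {
      return (int32_t)(((int64_t)a * b) >> 32);
    }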
+
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<SMulExtendedOpFold>(context);
+}
+
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // umulextended (x, 1) = <x, 0>
+    if (matchPattern(operands[1], m_One())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      constituents.push_back(operands[0]);
+      constituents.push_back(zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
+void spirv::UMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.UMod
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 93eceb036a93a23..84430e2690339ad 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -336,6 +336,52 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iaddcarry_x_0
+func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iaddcarry
+func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant -3
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant 1
+  // CHECK-DAG: spirv.Constant -13
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_iaddcarry
+func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[5, -5, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[0, 1, 1]>
+  // CHECK-NEXT: spirv.Constant dense<[-3, -13, 0]>
+  // CHECK-NEXT: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IMul
 //===----------------------------------------------------------------------===//
@@ -400,6 +446,108 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smulextended_x_0
+func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_smulextended
+func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: spirv.Constant -40
+  // CHECK-DAG: spirv.Constant -1
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant 40
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_smulextended
+func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+  // CHECK-NEXT: spirv.Constant dense<[2, 0, -1]>
+  // CHECK-NEXT: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umulextended_x_0
+func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @umulextended_x_1
+func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c1 = spirv.Constant 1 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %arg0, %c1 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_umulextended
+func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: spirv.Constant 40
+  // CHECK-DAG: spirv.Constant -13
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant -40
+  // CHECK-DAG: spirv.Constant 4
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_umulextended
+func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+  // CHECK-NEXT: spirv.Constant dense<[2, -13, 0]>
+  // CHECK-NEXT: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+
 //===----------------------------------------------------------------------===//
 // spirv.ISub
 //===----------------------------------------------------------------------===//


