[Mlir-commits] [mlir] [mlir][spirv] Add folding for [S|U]Mod, [S|U]Div, SRem (PR #73341)

Finn Plummer llvmlistbot at llvm.org
Fri Nov 24 07:30:47 PST 2023


https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/73341

Add the missing constant-propagation folders for [S|U]Mod, [S|U]Div, and SRem.

Additionally, fold each of these ops when the right-hand operand is 1
(x / 1 = x, x % 1 = 0).

This improves the readability of code lowered to SPIR-V.

Part of work for #70704

>From 1429c564002e251ef6431bbfe2d8d36d09967d70 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Fri, 24 Nov 2023 09:54:24 +0100
Subject: [PATCH] [mlir][spirv] Add folding for [S|U]Mod, [S|U]Div, SRem

Add the missing constant-propagation folders for [S|U]Mod, [S|U]Div, and SRem.

Additionally, fold each of these ops when the right-hand operand is 1
(x / 1 = x, x % 1 = 0).

This improves the readability of code lowered to SPIR-V.

Part of work for #70704
---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    |   9 +
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 160 +++++++++++
 .../SPIRV/Transforms/canonicalize.mlir        | 262 ++++++++++++++++++
 3 files changed, 431 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..16bf173cb7971e0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -534,6 +534,8 @@ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -573,6 +575,8 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -673,6 +677,8 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -707,6 +713,8 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -811,6 +819,7 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
     ```
   }];
 
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..8144a100dab3495 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -69,6 +69,14 @@ static Attribute extractCompositeElement(Attribute composite,
   return {};
 }
 
+static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
+  // True when a signed division (or remainder) of a by b is undefined in
+  // SPIR-V: either the divisor is zero, or the minimum signed value is
+  // divided by -1, whose quotient overflows the operand type.
+  return b.isZero() ||
+         (a.isMinSignedValue() && b.isAllOnes());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'erated canonicalizers
 //===----------------------------------------------------------------------===//
@@ -290,6 +298,158 @@ OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
       [](APInt a, const APInt &b) { return std::move(a) - b; });
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
+  // sdiv (x, 1) = x holds for any x, constant or not.
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // According to the SPIR-V spec, OpSDiv is undefined when Operand 2 is 0,
+  // and when Operand 2 is -1 while Operand 1 is the minimum representable
+  // value for the operands' type (signed overflow). Results are computed per
+  // component, so undefined behaviour in any lane blocks constant folding.
+  bool sawUndefined = false;
+  auto divide = [&](const APInt &lhs, const APInt &rhs) {
+    if (isDivZeroOrOverflow(lhs, rhs))
+      sawUndefined = true;
+    // After UB is seen the lane value is a placeholder that gets discarded.
+    return sawUndefined ? lhs : lhs.sdiv(rhs);
+  };
+  Attribute folded =
+      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), divide);
+
+  // Bail out rather than materialize a value for undefined behaviour.
+  if (sawUndefined)
+    return Attribute();
+  return folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
+  // smod (x, 1) = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  //
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
+  //
+  // So don't fold during undefined behaviour.
+  bool sawUndefined = false;
+  auto modulo = [&](const APInt &lhs, const APInt &rhs) {
+    if (isDivZeroOrOverflow(lhs, rhs))
+      sawUndefined = true;
+    if (sawUndefined)
+      return lhs;
+    // Floored modulo: start from the truncated remainder (sign of Operand 1)
+    // and, if it is non-zero and its sign differs from Operand 2, add
+    // Operand 2 so that the result takes the sign of Operand 2.
+    APInt r = lhs.srem(rhs);
+    if (!r.isZero() && r.isNegative() != rhs.isNegative())
+      r += rhs;
+    return r;
+  };
+  Attribute folded =
+      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), modulo);
+  return sawUndefined ? Attribute() : folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
+  // srem (x, 1) = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  //
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 1.
+  //
+  // That is APInt::srem; undefined behaviour in any lane blocks folding.
+  bool sawUndefined = false;
+  auto remainder = [&](const APInt &lhs, const APInt &rhs) {
+    if (isDivZeroOrOverflow(lhs, rhs))
+      sawUndefined = true;
+    return sawUndefined ? lhs : lhs.srem(rhs);
+  };
+  Attribute folded =
+      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), remainder);
+  // Bail out rather than materialize a value for undefined behaviour.
+  return sawUndefined ? Attribute() : folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
+  // udiv (x, 1) = x holds for any x, constant or not.
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // According to the SPIR-V spec, OpUDiv is undefined when Operand 2 is 0,
+  // so a zero divisor in any lane blocks constant folding.
+  bool sawZeroDivisor = false;
+  auto divide = [&](const APInt &lhs, const APInt &rhs) {
+    if (rhs.isZero())
+      sawZeroDivisor = true;
+    // After UB is seen the lane value is a placeholder that gets discarded.
+    return sawZeroDivisor ? lhs : lhs.udiv(rhs);
+  };
+  Attribute folded =
+      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), divide);
+
+  // Bail out rather than materialize a value for undefined behaviour.
+  if (sawZeroDivisor)
+    return Attribute();
+  return folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
+  // umod (x, 1) = 0 holds for any x, so fold to a zero of the result type.
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec, OpUMod is undefined when Operand 2 is 0,
+  // so a zero divisor in any lane blocks constant folding.
+  bool sawZeroDivisor = false;
+  auto remainder = [&](const APInt &lhs, const APInt &rhs) {
+    if (rhs.isZero())
+      sawZeroDivisor = true;
+    // After UB is seen the lane value is a placeholder that gets discarded.
+    return sawZeroDivisor ? lhs : lhs.urem(rhs);
+  };
+  Attribute folded =
+      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), remainder);
+
+  // Bail out rather than materialize a value for undefined behaviour.
+  if (sawZeroDivisor)
+    return Attribute();
+  return folded;
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a..7b1163601e1b427 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -462,10 +462,272 @@ func.func @const_fold_vector_isub() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sdiv_x_1
+func.func @sdiv_x_1(%arg0 : i32) -> i32 {  // sdiv(x, 1) must fold to x.
+  // CHECK-NEXT: return %arg0 : i32
+  %c1 = spirv.Constant 1  : i32
+  %2 = spirv.SDiv %arg0, %c1: i32
+  return %2 : i32
+}
+
+// CHECK-LABEL: @sdiv_div_0_or_overflow
+func.func @sdiv_div_0_or_overflow() -> (i32, i32) {  // UB inputs: no folding.
+  // CHECK: spirv.SDiv
+  // CHECK: spirv.SDiv
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+
+  %0 = spirv.SDiv %cn1, %c0 : i32  // Division by zero.
+  %1 = spirv.SDiv %min_i32, %cn1 : i32  // INT32_MIN / -1 overflows.
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_sdiv
+func.func @const_fold_scalar_sdiv() -> (i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: spirv.Constant -18
+  // CHECK-DAG: spirv.Constant -2
+  // CHECK-DAG: spirv.Constant -7
+  // CHECK-DAG: spirv.Constant 8
+  %0 = spirv.SDiv %c56, %c7 : i32  // 56 / 7 = 8
+  %1 = spirv.SDiv %c56, %cn8 : i32  // 56 / -8 = -7
+  %2 = spirv.SDiv %cn8, %c3 : i32  // -8 / 3 = -2 (rounds toward zero)
+  %3 = spirv.SDiv %c56, %cn3 : i32  // 56 / -3 = -18 (rounds toward zero)
+  return %0, %1, %2, %3: i32, i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_sdiv
+func.func @const_fold_vector_sdiv() -> vector<3xi32> {
+  // CHECK: spirv.Constant dense<[0, -1, -3]>
+
+  %cv_num = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+  %cv_denom = spirv.Constant dense<[76, -24, 5]> : vector<3xi32>
+  %0 = spirv.SDiv %cv_num, %cv_denom : vector<3xi32>  // Folded lane-wise.
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smod_x_1
+func.func @smod_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant dense<0>
+  %c1 = spirv.Constant 1 : i32
+  %cv1 = spirv.Constant dense<1> : vector<3xi32>
+  %0 = spirv.SMod %arg0, %c1: i32  // smod(x, 1) folds to 0.
+  %1 = spirv.SMod %arg1, %cv1: vector<3xi32>  // Splat 1 divisor: folds to dense<0>.
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @smod_div_0_or_overflow
+func.func @smod_div_0_or_overflow() -> (i32, i32) {  // UB inputs: no folding.
+  // CHECK: spirv.SMod
+  // CHECK: spirv.SMod
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+
+  %0 = spirv.SMod %cn1, %c0 : i32  // Division by zero.
+  %1 = spirv.SMod %min_i32, %cn1 : i32  // INT32_MIN by -1 overflows.
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_smod
+func.func @const_fold_scalar_smod() -> (i32, i32, i32, i32, i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %cn56 = spirv.Constant -56 : i32
+  %c59 = spirv.Constant 59 : i32
+  %cn59 = spirv.Constant -59 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+  // CHECK-DAG: %[[FIFTYTHREE:.*]] = spirv.Constant 53 : i32
+  // CHECK-DAG: %[[NFIFTYTHREE:.*]] = spirv.Constant -53 : i32
+  // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  %0 = spirv.SMod %c56, %c7 : i32  // 0
+  %1 = spirv.SMod %c56, %cn8 : i32  // 0
+  %2 = spirv.SMod %c56, %c3 : i32  // 2
+  %3 = spirv.SMod %cn3, %c56 : i32  // 53: sign follows Operand 2
+  %4 = spirv.SMod %cn3, %cn56 : i32  // -3
+  %5 = spirv.SMod %c59, %c56 : i32  // 3
+  %6 = spirv.SMod %c59, %cn56 : i32  // -53: sign follows Operand 2
+  %7 = spirv.SMod %cn59, %cn56 : i32  // -3
+
+  // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[FIFTYTHREE]], %[[NTHREE]], %[[THREE]], %[[NFIFTYTHREE]], %[[NTHREE]]
+  return %0, %1, %2, %3, %4, %5, %6, %7 : i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_smod
+func.func @const_fold_vector_smod() -> vector<3xi32> {
+  // CHECK: spirv.Constant dense<[42, -4, 4]>
+
+  %cv = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+  %cv_mod = spirv.Constant dense<[76, -7, 5]> : vector<3xi32>
+  %0 = spirv.SMod %cv, %cv_mod : vector<3xi32>  // Lane-wise; each sign follows Operand 2.
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @srem_x_1
+func.func @srem_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant dense<0>
+  %c1 = spirv.Constant 1 : i32
+  %cv1 = spirv.Constant dense<1> : vector<3xi32>
+  %0 = spirv.SRem %arg0, %c1: i32  // srem(x, 1) folds to 0.
+  %1 = spirv.SRem %arg1, %cv1: vector<3xi32>  // Splat 1 divisor: folds to dense<0>.
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @srem_div_0_or_overflow
+func.func @srem_div_0_or_overflow() -> (i32, i32) {  // UB inputs: no folding.
+  // CHECK: spirv.SRem
+  // CHECK: spirv.SRem
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+
+  %0 = spirv.SRem %cn1, %c0 : i32  // Division by zero.
+  %1 = spirv.SRem %min_i32, %cn1 : i32  // INT32_MIN by -1 overflows.
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_srem
+func.func @const_fold_scalar_srem() -> (i32, i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[ONE:.*]] = spirv.Constant 1 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  %0 = spirv.SRem %c56, %c7 : i32  // 0
+  %1 = spirv.SRem %c56, %cn8 : i32  // 0
+  %2 = spirv.SRem %c56, %c3 : i32  // 2
+  %3 = spirv.SRem %cn3, %c56 : i32  // -3: sign follows Operand 1
+  %4 = spirv.SRem %c7, %cn3 : i32  // 1: sign follows Operand 1
+  // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[NTHREE]], %[[ONE]]
+  return %0, %1, %2, %3, %4 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @udiv_x_1
+func.func @udiv_x_1(%arg0 : i32) -> i32 {  // udiv(x, 1) must fold to x.
+  // CHECK-NEXT: return %arg0 : i32
+  %c1 = spirv.Constant 1  : i32
+  %2 = spirv.UDiv %arg0, %c1: i32
+  return %2 : i32
+}
+
+// CHECK-LABEL: @udiv_div_0
+func.func @udiv_div_0() -> i32 {  // UB input: no folding.
+  // CHECK: spirv.UDiv
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %0 = spirv.UDiv %cn1, %c0 : i32  // Division by zero.
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_udiv
+func.func @const_fold_scalar_udiv() -> (i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant 1431655762
+  // CHECK-DAG: spirv.Constant 8
+  %0 = spirv.UDiv %c56, %c7 : i32  // 56 / 7 = 8
+  %1 = spirv.UDiv %cn8, %c3 : i32  // 0xFFFFFFF8 / 3 = 1431655762 (unsigned)
+  %2 = spirv.UDiv %c56, %cn8 : i32  // 56 / 0xFFFFFFF8 = 0 (unsigned)
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.UMod
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @umod_x_1
+func.func @umod_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant dense<0>
+  %c1 = spirv.Constant 1 : i32
+  %cv1 = spirv.Constant dense<1> : vector<3xi32>
+  %0 = spirv.UMod %arg0, %c1: i32  // umod(x, 1) folds to 0.
+  %1 = spirv.UMod %arg1, %cv1: vector<3xi32>  // Splat 1 divisor: folds to dense<0>.
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @umod_div_0
+func.func @umod_div_0() -> i32 {  // UB input: no folding.
+  // CHECK: spirv.UMod
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %0 = spirv.UMod %cn1, %c0 : i32  // Division by zero.
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_umod
+func.func @const_fold_scalar_umod() -> (i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant 2
+  // CHECK-DAG: spirv.Constant 56
+  %0 = spirv.UMod %c56, %c7 : i32  // 56 mod 7 = 0
+  %1 = spirv.UMod %cn8, %c3 : i32  // 0xFFFFFFF8 mod 3 = 2 (unsigned)
+  %2 = spirv.UMod %c56, %cn8 : i32  // 56 mod 0xFFFFFFF8 = 56 (unsigned)
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_umod
+func.func @const_fold_vector_umod() -> vector<3xi32> {
+  // CHECK: spirv.Constant dense<[42, 24, 0]>
+
+  %cv = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+  %cv_mod = spirv.Constant dense<[76, -7, 5]> : vector<3xi32>
+  %0 = spirv.UMod %cv, %cv_mod : vector<3xi32>  // Lane-wise; operands read as unsigned.
+  return %0 : vector<3xi32>
+}
+
 // CHECK-LABEL: @umod_fold
 // CHECK-SAME: (%[[ARG:.*]]: i32)
 func.func @umod_fold(%arg0: i32) -> (i32, i32) {



More information about the Mlir-commits mailing list