[Mlir-commits] [mlir] 9b32886 - [mlir][Arithmetic] Use common constant fold function in RemSI and RemUI to cover splat.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 22 02:20:33 PDT 2022


Author: jacquesguan
Date: 2022-04-22T09:20:18Z
New Revision: 9b32886e7e705bb28aab57682e612375075a0ad7

URL: https://github.com/llvm/llvm-project/commit/9b32886e7e705bb28aab57682e612375075a0ad7
DIFF: https://github.com/llvm/llvm-project/commit/9b32886e7e705bb28aab57682e612375075a0ad7.diff

LOG: [mlir][Arithmetic] Use common constant fold function in RemSI and RemUI to cover splat.

This patch replaces the current fold functions of RemSI and RemUI with the common constant fold function (constFoldBinaryOp) so that constant splat operands are also folded.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D124236

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 5a104400c48a8..8f26e66394287 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -444,23 +444,22 @@ OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
-  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
-  if (!rhs)
-    return {};
-  auto rhsValue = rhs.getValue();
-
-  // x % 1 = 0
-  if (rhsValue.isOneValue())
-    return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
+  // remui (x, 1) -> 0; m_One() also matches splats of one (@test_remui_1).
+  if (matchPattern(getRhs(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
 
-  // Don't fold if it requires division by zero.
-  if (rhsValue.isNullValue())
-    return {};
+  // Fold per element, but not if any divisor element is zero.
+  bool div0 = false;
+  auto result =
+      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+        if (div0 || b.isNullValue()) { // Skip urem after a zero divisor.
+          div0 = true;
+          return a; // Placeholder; the result is discarded below.
+        }
+        return a.urem(b);
+      });
 
-  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
-  if (!lhs)
-    return {};
-  return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
+  return div0 ? Attribute() : result; // Null Attribute: not folded.
 }
 
 //===----------------------------------------------------------------------===//
@@ -468,23 +467,22 @@ OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
-  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
-  if (!rhs)
-    return {};
-  auto rhsValue = rhs.getValue();
-
-  // x % 1 = 0
-  if (rhsValue.isOneValue())
-    return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
+  // remsi (x, 1) -> 0; m_One() also matches splats of one (@test_remsi_1).
+  if (matchPattern(getRhs(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
 
-  // Don't fold if it requires division by zero.
-  if (rhsValue.isNullValue())
-    return {};
+  // Fold per element, but not if any divisor element is zero.
+  bool div0 = false;
+  auto result =
+      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+        if (div0 || b.isNullValue()) { // Skip srem after a zero divisor.
+          div0 = true;
+          return a; // Placeholder; the result is discarded below.
+        }
+        return a.srem(b);
+      });
 
-  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
-  if (!lhs)
-    return {};
-  return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
+  return div0 ? Attribute() : result; // Null Attribute: not folded.
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 5ea88b3569e3d..1f6b47330f144 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1319,3 +1319,49 @@ func.func @test_negf() -> (f32) {
   %0 = arith.negf %c : f32
   return %0: f32
 }
+
+// -----
+
+// CHECK-LABEL: @test_remui(
+// CHECK: %[[res:.+]] = arith.constant dense<[0, 0, 4, 2]> : vector<4xi32>
+// CHECK: return %[[res]]
+func.func @test_remui() -> (vector<4xi32>) {
+  %v1 = arith.constant dense<[9, 9, 9, 9]> : vector<4xi32>
+  %v2 = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+  %0 = arith.remui %v1, %v2 : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_remui_1(
+// CHECK: %[[res:.+]] = arith.constant dense<0> : vector<4xi32>
+// CHECK: return %[[res]]
+func.func @test_remui_1(%arg : vector<4xi32>) -> (vector<4xi32>) {
+  %v = arith.constant dense<[1, 1, 1, 1]> : vector<4xi32>
+  %0 = arith.remui %arg, %v : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_remsi(
+// CHECK: %[[res:.+]] = arith.constant dense<[0, 0, 4, 2]> : vector<4xi32>
+// CHECK: return %[[res]]
+func.func @test_remsi() -> (vector<4xi32>) {
+  %v1 = arith.constant dense<[9, 9, 9, 9]> : vector<4xi32>
+  %v2 = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+  %0 = arith.remsi %v1, %v2 : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_remsi_1(
+// CHECK: %[[res:.+]] = arith.constant dense<0> : vector<4xi32>
+// CHECK: return %[[res]]
+func.func @test_remsi_1(%arg : vector<4xi32>) -> (vector<4xi32>) {
+  %v = arith.constant dense<[1, 1, 1, 1]> : vector<4xi32>
+  %0 = arith.remsi %arg, %v : vector<4xi32>
+  return %0 : vector<4xi32>
+}


        


More information about the Mlir-commits mailing list