[Mlir-commits] [mlir] f97f946 - Canonicalize max/min operations on integers.

Eugene Zhulenev llvmlistbot at llvm.org
Tue Oct 19 05:26:05 PDT 2021


Author: bakhtiyar
Date: 2021-10-19T05:25:59-07:00
New Revision: f97f946839d18ca88b4e8f32d45e458f124bdf6b

URL: https://github.com/llvm/llvm-project/commit/f97f946839d18ca88b4e8f32d45e458f124bdf6b
DIFF: https://github.com/llvm/llvm-project/commit/f97f946839d18ca88b4e8f32d45e458f124bdf6b.diff

LOG: Canonicalize max/min operations on integers.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D112051

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 2b08d73a9d72b..4154a6f6614b1 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -751,6 +751,7 @@ def MaxSIOp : IntBinaryOp<"maxsi"> {
     %a = maxsi %b, %c : i64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -775,6 +776,7 @@ def MaxUIOp : IntBinaryOp<"maxui"> {
     %a = maxui %b, %c : i64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -824,6 +826,7 @@ def MinSIOp : IntBinaryOp<"minsi"> {
     %a = minsi %b, %c : i64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -848,6 +851,7 @@ def MinUIOp : IntBinaryOp<"minui"> {
     %a = minui %b, %c : i64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index d4b1b1e98f99e..81678c912e48e 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -939,6 +939,106 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
   return value.isa<UnitAttr>();
 }
 
+//===----------------------------------------------------------------------===//
+// MaxSIOp
+//===----------------------------------------------------------------------===//
+
+/// Folds maxsi when the result is one of its operands or both are constants.
+OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // maxsi(x,x) -> x
+  if (lhs() == rhs())
+    return rhs();
+
+  // maxsi(x,MAX) -> MAX, maxsi(x,MIN) -> x; no Commutative trait, check both.
+  APInt intValue;
+  for (Value cst : {rhs(), lhs()}) {
+    if (matchPattern(cst, m_ConstantInt(&intValue)) &&
+        intValue.isMaxSignedValue())
+      return cst;
+    if (matchPattern(cst, m_ConstantInt(&intValue)) &&
+        intValue.isMinSignedValue())
+      return cst == rhs() ? lhs() : rhs();
+  }
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](APInt a, APInt b) { return llvm::APIntOps::smax(a, b); });
+}
+
+//===----------------------------------------------------------------------===//
+// MaxUIOp
+//===----------------------------------------------------------------------===//
+
+/// Folds maxui when the result is one of its operands or both are constants.
+OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // maxui(x,x) -> x
+  if (lhs() == rhs())
+    return rhs();
+
+  // maxui(x,UMAX) -> UMAX, maxui(x,0) -> x; no Commutative trait, check both.
+  APInt intValue;
+  for (Value cst : {rhs(), lhs()}) {
+    if (matchPattern(cst, m_ConstantInt(&intValue)) && intValue.isMaxValue())
+      return cst;
+    if (matchPattern(cst, m_ConstantInt(&intValue)) && intValue.isMinValue())
+      return cst == rhs() ? lhs() : rhs();
+  }
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](APInt a, APInt b) { return llvm::APIntOps::umax(a, b); });
+}
+
+//===----------------------------------------------------------------------===//
+// MinSIOp
+//===----------------------------------------------------------------------===//
+
+/// Folds minsi when the result is one of its operands or both are constants.
+OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // minsi(x,x) -> x
+  if (lhs() == rhs())
+    return rhs();
+
+  // minsi(x,MIN) -> MIN, minsi(x,MAX) -> x; no Commutative trait, check both.
+  APInt intValue;
+  for (Value cst : {rhs(), lhs()}) {
+    if (matchPattern(cst, m_ConstantInt(&intValue)) &&
+        intValue.isMinSignedValue())
+      return cst;
+    if (matchPattern(cst, m_ConstantInt(&intValue)) &&
+        intValue.isMaxSignedValue())
+      return cst == rhs() ? lhs() : rhs();
+  }
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](APInt a, APInt b) { return llvm::APIntOps::smin(a, b); });
+}
+
+//===----------------------------------------------------------------------===//
+// MinUIOp
+//===----------------------------------------------------------------------===//
+
+/// Folds minui when the result is one of its operands or both are constants.
+OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // minui(x,x) -> x
+  if (lhs() == rhs())
+    return rhs();
+
+  // minui(x,0) -> 0, minui(x,UMAX) -> x; no Commutative trait, check both.
+  APInt intValue;
+  for (Value cst : {rhs(), lhs()}) {
+    if (matchPattern(cst, m_ConstantInt(&intValue)) && intValue.isMinValue())
+      return cst;
+    if (matchPattern(cst, m_ConstantInt(&intValue)) && intValue.isMaxValue())
+      return cst == rhs() ? lhs() : rhs();
+  }
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](APInt a, APInt b) { return llvm::APIntOps::umin(a, b); });
+}
+
 //===----------------------------------------------------------------------===//
 // RankOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 875d9f7bc4fa2..2ba3fe1fa6000 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -60,3 +60,68 @@ func @selToNot(%arg0: i1) -> i1 {
   %res = select %arg0, %false, %true : i1
   return %res : i1
 }
+
+// CHECK-LABEL: test_maxsi
+// CHECK: %[[C0:.+]] = arith.constant 42
+// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127
+// CHECK: %[[X:.+]] = maxsi %arg0, %[[C0]]
+// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
+func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) { // maxsi folds on i8 values
+  %maxIntCst = arith.constant 127 : i8   // largest signed i8
+  %minIntCst = arith.constant -128 : i8  // smallest signed i8
+  %c0 = arith.constant 42 : i8           // ordinary value, no special fold
+  %0 = maxsi %arg0, %arg0 : i8           // maxsi(x, x) -> x
+  %1 = maxsi %arg0, %maxIntCst : i8      // maxsi(x, 127) -> 127
+  %2 = maxsi %arg0, %minIntCst : i8      // maxsi(x, -128) -> x
+  %3 = maxsi %arg0, %c0 : i8             // expected to remain a maxsi
+  return %0, %1, %2, %3: i8, i8, i8, i8
+}
+
+// CHECK-LABEL: test_maxui
+// CHECK: %[[C0:.+]] = arith.constant 42
+// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1
+// CHECK: %[[X:.+]] = maxui %arg0, %[[C0]]
+// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
+func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) { // maxui folds on i8 values
+  %maxIntCst = arith.constant 255 : i8   // largest unsigned i8 (printed as -1)
+  %minIntCst = arith.constant 0 : i8     // smallest unsigned i8
+  %c0 = arith.constant 42 : i8           // ordinary value, no special fold
+  %0 = maxui %arg0, %arg0 : i8           // maxui(x, x) -> x
+  %1 = maxui %arg0, %maxIntCst : i8      // maxui(x, 255) -> 255
+  %2 = maxui %arg0, %minIntCst : i8      // maxui(x, 0) -> x
+  %3 = maxui %arg0, %c0 : i8             // expected to remain a maxui
+  return %0, %1, %2, %3: i8, i8, i8, i8
+}
+
+
+// CHECK-LABEL: test_minsi
+// CHECK: %[[C0:.+]] = arith.constant 42
+// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128
+// CHECK: %[[X:.+]] = minsi %arg0, %[[C0]]
+// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
+func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) { // minsi folds on i8 values
+  %maxIntCst = arith.constant 127 : i8   // largest signed i8
+  %minIntCst = arith.constant -128 : i8  // smallest signed i8
+  %c0 = arith.constant 42 : i8           // ordinary value, no special fold
+  %0 = minsi %arg0, %arg0 : i8           // minsi(x, x) -> x
+  %1 = minsi %arg0, %maxIntCst : i8      // minsi(x, 127) -> x
+  %2 = minsi %arg0, %minIntCst : i8      // minsi(x, -128) -> -128
+  %3 = minsi %arg0, %c0 : i8             // expected to remain a minsi
+  return %0, %1, %2, %3: i8, i8, i8, i8
+}
+
+// CHECK-LABEL: test_minui
+// CHECK: %[[C0:.+]] = arith.constant 42
+// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0
+// CHECK: %[[X:.+]] = minui %arg0, %[[C0]]
+// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
+func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) { // minui folds on i8 values
+  %maxIntCst = arith.constant 255 : i8   // largest unsigned i8
+  %minIntCst = arith.constant 0 : i8     // smallest unsigned i8
+  %c0 = arith.constant 42 : i8           // ordinary value, no special fold
+  %0 = minui %arg0, %arg0 : i8           // minui(x, x) -> x
+  %1 = minui %arg0, %maxIntCst : i8      // minui(x, 255) -> x
+  %2 = minui %arg0, %minIntCst : i8      // minui(x, 0) -> 0
+  %3 = minui %arg0, %c0 : i8             // expected to remain a minui
+  return %0, %1, %2, %3: i8, i8, i8, i8
+}


        


More information about the Mlir-commits mailing list