[Mlir-commits] [mlir] [mlir] Use arith max or min ops instead of cmp + select (PR #82178)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 21 12:04:09 PST 2024
https://github.com/mlevesquedion updated https://github.com/llvm/llvm-project/pull/82178
>From d345304b27f683af4aa66e49b484e63a66087a7f Mon Sep 17 00:00:00 2001
From: Michael Levesque-Dion <mlevesquedion at google.com>
Date: Sat, 17 Feb 2024 21:21:50 -0800
Subject: [PATCH 1/2] [mlir] Use arith max or min ops instead of cmp + select
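Replace each arith.cmpi + arith.select pair that computes a min or max with
the dedicated arith.{max,min}{si,ui} op. The semantics should be unchanged,
and each replacement saves one op and simplifies the code.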
---
.../AffineToStandard/AffineToStandard.cpp | 17 ++--
.../ShapeToStandard/ShapeToStandard.cpp | 8 +-
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 20 ++---
.../TosaToLinalg/TosaToLinalgNamed.cpp | 9 +-
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 6 +-
.../Dialect/Tosa/Utils/ConversionUtils.cpp | 9 +-
.../AffineToStandard/lower-affine.mlir | 42 ++++------
.../expand-then-convert-to-llvm.mlir | 3 +-
.../ShapeToStandard/shape-to-standard.mlir | 18 ++--
.../TosaToLinalg/tosa-to-linalg-named.mlir | 44 ++++------
.../TosaToLinalg/tosa-to-linalg-resize.mlir | 83 +++++++------------
.../TosaToLinalg/tosa-to-linalg.mlir | 61 +++++---------
mlir/test/Transforms/parametric-tiling.mlir | 12 +--
13 files changed, 114 insertions(+), 218 deletions(-)
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 15ad6d8cdf629d..98cdfc63252711 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -34,12 +34,7 @@ using namespace mlir::affine;
using namespace mlir::vector;
/// Given a range of values, emit the code that reduces them with "min" or "max"
-/// depending on the provided comparison predicate. The predicate defines which
-/// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
-/// `cmpi` operation followed by the `select` operation:
-///
-/// %cond = arith.cmpi "predicate" %v0, %v1
-/// %result = select %cond, %v0, %v1
+/// depending on the provided comparison predicate: `sgt` selects max, `slt` min.
///
/// Multiple values are scanned in a linear sequence. This creates a data
/// dependences that wouldn't exist in a tree reduction, but is easier to
@@ -48,13 +43,17 @@ static Value buildMinMaxReductionSeq(Location loc,
arith::CmpIPredicate predicate,
ValueRange values, OpBuilder &builder) {
assert(!values.empty() && "empty min/max chain");
+ assert(predicate == arith::CmpIPredicate::sgt ||
+ predicate == arith::CmpIPredicate::slt);
auto valueIt = values.begin();
Value value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) {
- auto cmpOp = builder.create<arith::CmpIOp>(loc, predicate, value, *valueIt);
- value = builder.create<arith::SelectOp>(loc, cmpOp.getResult(), value,
- *valueIt);
+ if (predicate == arith::CmpIPredicate::sgt) {
+ value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
+ } else {
+ value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
+ }
}
return value;
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index a3e51aeed0735a..de649f730ee9d7 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -147,9 +147,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
- Value rankIsGreater =
- lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
- maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
+ maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
}
// Calculate the difference of ranks and the maximum rank for later offsets.
@@ -262,9 +260,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
- Value rankIsGreater =
- lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
- maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
+ maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
}
// Calculate the difference of ranks and the maximum rank for later offsets.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f4f6dadfb37166..7eb32ebe3228fb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -61,10 +61,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementTy));
- auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
- args[0], zero);
auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
- return rewriter.create<arith::SelectOp>(loc, cmp, args[0], neg);
+ return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
}
// tosa::AddOp
@@ -348,9 +346,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
- auto predicate = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
}
// tosa::MinimumOp
@@ -359,9 +355,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
- auto predicate = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
}
// tosa::CeilOp
@@ -1000,9 +994,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
}
if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
- auto predicate = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
}
if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
@@ -1010,9 +1002,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
}
if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
- auto predicate = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
}
if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 607a603cca810f..3f39cbf03a9a80 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -845,10 +845,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);
- Value cmp = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, dpos, zero);
- Value offset =
- rewriter.create<arith::SelectOp>(loc, cmp, dpos, zero);
+ Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
return rewriter.create<arith::AddIOp>(loc, valid, offset)
->getResult(0);
};
@@ -868,9 +865,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// Determine how much padding was included.
val = padFn(val, left, pad[i * 2]);
val = padFn(val, right, pad[i * 2 + 1]);
- Value cmp = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, val, one);
- return rewriter.create<arith::SelectOp>(loc, cmp, one, val);
+ return rewriter.create<arith::MaxSIOp>(loc, one, val);
};
// Compute the indices from either end.
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 536c02feca1bd5..502d7e197a6f6b 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -791,10 +791,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor,
// Insert newForOp before the terminator of `t`.
auto b = OpBuilder::atBlockTerminator((t.getBody()));
Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
- Value less = b.create<arith::CmpIOp>(t.getLoc(), arith::CmpIPredicate::slt,
- forOp.getUpperBound(), stepped);
- Value ub = b.create<arith::SelectOp>(t.getLoc(), less,
- forOp.getUpperBound(), stepped);
+ Value ub =
+ b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);
// Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index ee428b201d0073..4fc97115064f33 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -39,13 +39,8 @@ Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
OpBuilder &rewriter) {
- auto smallerThanMin =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min);
- auto minOrArg =
- rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
- auto largerThanMax =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
- return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
+ auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
+ return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
}
bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 92608135d24b08..00d7b6b8d65f67 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -371,16 +371,14 @@ func.func @if_for() {
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[c42]] step %[[c1]] {
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
-// CHECK-NEXT: %[[b:.*]] = arith.addi %[[a]], %{{.*}} : index
-// CHECK-NEXT: %[[c:.*]] = arith.cmpi sgt, %{{.*}}, %[[b]] : index
-// CHECK-NEXT: %[[d:.*]] = arith.select %[[c]], %{{.*}}, %[[b]] : index
+// CHECK-NEXT: %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
+// CHECK-NEXT: %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} : index
+// CHECK-NEXT: %[[max:.*]] = arith.maxsi %{{.*}}, %[[add0]] : index
// CHECK-NEXT: %[[c10:.*]] = arith.constant 10 : index
-// CHECK-NEXT: %[[e:.*]] = arith.addi %{{.*}}, %[[c10]] : index
-// CHECK-NEXT: %[[f:.*]] = arith.cmpi slt, %{{.*}}, %[[e]] : index
-// CHECK-NEXT: %[[g:.*]] = arith.select %[[f]], %{{.*}}, %[[e]] : index
+// CHECK-NEXT: %[[add1:.*]] = arith.addi %{{.*}}, %[[c10]] : index
+// CHECK-NEXT: %[[min:.*]] = arith.minsi %{{.*}}, %[[add1]] : index
// CHECK-NEXT: %[[c1_0:.*]] = arith.constant 1 : index
-// CHECK-NEXT: for %{{.*}} = %[[d]] to %[[g]] step %[[c1_0]] {
+// CHECK-NEXT: for %{{.*}} = %[[max]] to %[[min]] step %[[c1_0]] {
// CHECK-NEXT: call @body2(%{{.*}}, %{{.*}}) : (index, index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
@@ -397,25 +395,19 @@ func.func @loop_min_max(%N : index) {
#map_7_values = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
-// Check that the "min" (cmpi slt + select) reduction sequence is emitted
+// Check that the "min" reduction sequence is emitted
// correctly for an affine map with 7 results.
// CHECK-LABEL: func @min_reduction_tree
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[c01:.+]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
-// CHECK-NEXT: %[[r01:.+]] = arith.select %[[c01]], %{{.*}}, %{{.*}} : index
-// CHECK-NEXT: %[[c012:.+]] = arith.cmpi slt, %[[r01]], %{{.*}} : index
-// CHECK-NEXT: %[[r012:.+]] = arith.select %[[c012]], %[[r01]], %{{.*}} : index
-// CHECK-NEXT: %[[c0123:.+]] = arith.cmpi slt, %[[r012]], %{{.*}} : index
-// CHECK-NEXT: %[[r0123:.+]] = arith.select %[[c0123]], %[[r012]], %{{.*}} : index
-// CHECK-NEXT: %[[c01234:.+]] = arith.cmpi slt, %[[r0123]], %{{.*}} : index
-// CHECK-NEXT: %[[r01234:.+]] = arith.select %[[c01234]], %[[r0123]], %{{.*}} : index
-// CHECK-NEXT: %[[c012345:.+]] = arith.cmpi slt, %[[r01234]], %{{.*}} : index
-// CHECK-NEXT: %[[r012345:.+]] = arith.select %[[c012345]], %[[r01234]], %{{.*}} : index
-// CHECK-NEXT: %[[c0123456:.+]] = arith.cmpi slt, %[[r012345]], %{{.*}} : index
-// CHECK-NEXT: %[[r0123456:.+]] = arith.select %[[c0123456]], %[[r012345]], %{{.*}} : index
+// CHECK-NEXT: %[[min:.+]] = arith.minsi %{{.*}}, %{{.*}} : index
+// CHECK-NEXT: %[[min_0:.+]] = arith.minsi %[[min]], %{{.*}} : index
+// CHECK-NEXT: %[[min_1:.+]] = arith.minsi %[[min_0]], %{{.*}} : index
+// CHECK-NEXT: %[[min_2:.+]] = arith.minsi %[[min_1]], %{{.*}} : index
+// CHECK-NEXT: %[[min_3:.+]] = arith.minsi %[[min_2]], %{{.*}} : index
+// CHECK-NEXT: %[[min_4:.+]] = arith.minsi %[[min_3]], %{{.*}} : index
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[r0123456]] step %[[c1]] {
+// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[min_4]] step %[[c1]] {
// CHECK-NEXT: call @body(%{{.*}}) : (index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: return
@@ -690,8 +682,7 @@ func.func @affine_min(%arg0: index, %arg1: index) -> index{
// CHECK: %[[Cm2:.*]] = arith.constant -1
// CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
// CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
- // CHECK: %[[cmp:.*]] = arith.cmpi slt, %[[first]], %[[second]]
- // CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
+ // CHECK: arith.minsi %[[first]], %[[second]]
%0 = affine.min affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
return %0 : index
}
@@ -705,8 +696,7 @@ func.func @affine_max(%arg0: index, %arg1: index) -> index{
// CHECK: %[[Cm2:.*]] = arith.constant -1
// CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
// CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
- // CHECK: %[[cmp:.*]] = arith.cmpi sgt, %[[first]], %[[second]]
- // CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
+ // CHECK: arith.maxsi %[[first]], %[[second]]
%0 = affine.max affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
return %0 : index
}
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index eb45112b117c0d..87d613986c7c3f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -554,8 +554,7 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK: %[[IS_MIN_STRIDE1:.*]] = llvm.icmp "slt" %[[STRIDE1]], %[[C1]] : i64
-// CHECK: %[[MIN_STRIDE1:.*]] = llvm.select %[[IS_MIN_STRIDE1]], %[[STRIDE1]], %[[C1]] : i1, i64
+// CHECK: %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
// CHECK: %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
// CHECK: %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index cb3af973daee20..3b73c513b7955f 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -377,10 +377,8 @@ func.func @try_is_broadcastable (%a : tensor<2xindex>, %b : tensor<3xindex>, %c
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
@@ -467,10 +465,8 @@ func.func @broadcast(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xi
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
@@ -559,10 +555,8 @@ func.func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 51ebcad0797807..e64903671e599f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -263,16 +263,13 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
// CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
// CHECK: %[[PAD_START:.+]] = arith.constant 1
// CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
- // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.minsi %[[START_SUB]], %[[ZERO]]
// CHECK: %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
// CHECK: %[[PAD_END:.+]] = arith.constant 1
// CHECK: %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
- // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.minsi %[[END_SUB]], %[[ZERO]]
// CHECK: %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
- // CHECK: %[[KHEIGHT:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+ // CHECK: %[[KHEIGHT:.+]] = arith.maxsi %[[ONE]], %[[END_OFFSET]]
// Compute how much of the width does not include padding:
// CHECK: %[[STRIDE:.+]] = arith.constant 1
@@ -283,16 +280,13 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
// CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
// CHECK: %[[PAD_START:.+]] = arith.constant 1
// CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
- // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.minsi %[[START_SUB]], %[[ZERO]]
// CHECK: %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
// CHECK: %[[PAD_END:.+]] = arith.constant 1
// CHECK: %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
- // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.minsi %[[END_SUB]], %[[ZERO]]
// CHECK: %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
- // CHECK: %[[KWIDTH:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+ // CHECK: %[[KWIDTH:.+]] = arith.maxsi %[[ONE]], %[[END_OFFSET]]
// Divide the summed value by the number of values summed.
// CHECK: %[[COUNT:.+]] = arith.muli %[[KHEIGHT]], %[[KWIDTH]]
@@ -353,16 +347,13 @@ func.func @avg_pool_f16_f32acc(%arg0: tensor<1x6x34x62xf16>) -> (tensor<1x5x33x6
// CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
// CHECK: %[[PAD_START:.+]] = arith.constant 1
// CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
- // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.minsi %[[START_SUB]], %[[ZERO]]
// CHECK: %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
// CHECK: %[[PAD_END:.+]] = arith.constant 1
// CHECK: %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
- // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.minsi %[[END_SUB]], %[[ZERO]]
// CHECK: %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
- // CHECK: %[[KHEIGHT:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+ // CHECK: %[[KHEIGHT:.+]] = arith.maxsi %[[ONE]], %[[END_OFFSET]]
// Compute how much of the width does not include padding:
// CHECK: %[[STRIDE:.+]] = arith.constant 1
@@ -373,16 +364,13 @@ func.func @avg_pool_f16_f32acc(%arg0: tensor<1x6x34x62xf16>) -> (tensor<1x5x33x6
// CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
// CHECK: %[[PAD_START:.+]] = arith.constant 1
// CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
- // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.minsi %[[START_SUB]], %[[ZERO]]
// CHECK: %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
// CHECK: %[[PAD_END:.+]] = arith.constant 1
// CHECK: %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
- // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+ // CHECK: %[[OFFSET:.+]] = arith.minsi %[[END_SUB]], %[[ZERO]]
// CHECK: %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
- // CHECK: %[[KWIDTH:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+ // CHECK: %[[KWIDTH:.+]] = arith.maxsi %[[ONE]], %[[END_OFFSET]]
// Divide the summed value by the number of values summed.
// CHECK: %[[COUNT:.+]] = arith.muli %[[KHEIGHT]], %[[KWIDTH]]
@@ -407,7 +395,7 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
// Only different behavior is how the division is performed.
// First we compute the mul and shift values for average pool:
- // CHECK: %[[COUNT:.+]] = arith.muli %21, %35
+ // CHECK: %[[COUNT:.+]] = arith.muli %{{[0-9]+}}, %{{[0-9]+}}
// CHECK: %[[ICAST:.+]] = arith.index_cast %[[COUNT]]
// CHECK: %[[C1:.+]] = arith.constant 1
// CHECK: %[[C32:.+]] = arith.constant 32
@@ -428,10 +416,8 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
// Perform the normalization.
// CHECK: %[[CMIN:.+]] = arith.constant -128
// CHECK: %[[CMAX:.+]] = arith.constant 127
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[SCALED]], %[[CMIN]]
- // CHECK: %[[SEL:.+]] = arith.select %[[CMP]], %[[CMIN]], %[[SCALED]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[CMAX]], %[[SCALED]]
- // CHECK: %[[CLAMP:.+]] = arith.select %[[CMP]], %[[CMAX]], %[[SEL]]
+ // CHECK: %[[LOW:.+]] = arith.maxsi %[[CMIN]], %[[SCALED]]
+ // CHECK: %[[CLAMP:.+]] = arith.minsi %[[CMAX]], %[[LOW]]
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]]
// CHECK: linalg.yield %[[TRUNC]]
%0 = tosa.avg_pool2d %arg0 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>) -> tensor<1x5x33x62xi8>
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
index aedc6b7fae4a45..468e92e2a2661f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
@@ -167,22 +167,18 @@ func.func @resize_nearest_int(%arg0: tensor<1x15x13x1xi8>) -> () {
// CHECK: %[[PRED_Y:.*]] = arith.cmpi sge, %[[D_Y_DOUBLE]], %[[SCALE_Y_N]]
// CHECK: %[[VAL_37:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]]
// CHECK: %[[VAL_39:.*]] = arith.addi %[[I_Y]], %[[VAL_37]]
- // CHECK: %[[VAL_41:.*]] = arith.cmpi slt, %[[VAL_39]], %[[ZERO]]
- // CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_41]], %[[ZERO]], %[[VAL_39]]
- // CHECK: %[[VAL_43:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[VAL_39]]
- // CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[Y_MAX]], %[[VAL_42]]
- // CHECK: %[[IDY:.+]] = arith.index_cast %[[VAL_44]]
+ // CHECK: %[[LOWER:.*]] = arith.maxsi %[[ZERO]], %[[VAL_39]]
+ // CHECK: %[[CLAMPED:.*]] = arith.minsi %[[Y_MAX]], %[[LOWER]]
+ // CHECK: %[[IDY:.+]] = arith.index_cast %[[CLAMPED]]
// Compute the offset and bound for the X position.
// CHECK: %[[D_X_DOUBLE:.*]] = arith.shli %[[D_X]], %[[ONE]]
// CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X_DOUBLE]], %[[SCALE_X_N]]
// CHECK: %[[VAL_38:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
// CHECK: %[[VAL_40:.*]] = arith.addi %[[I_X]], %[[VAL_38]]
- // CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_40]], %[[ZERO]]
- // CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[ZERO]], %[[VAL_40]]
- // CHECK: %[[VAL_47:.*]] = arith.cmpi slt, %[[X_MAX]], %[[VAL_40]]
- // CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_47]], %[[X_MAX]], %[[VAL_46]]
- // CHECK: %[[IDX:.+]] = arith.index_cast %[[VAL_48]]
+ // CHECK: %[[LOWER:.*]] = arith.maxsi %[[ZERO]], %[[VAL_40]]
+ // CHECK: %[[CLAMPED:.*]] = arith.minsi %[[X_MAX]], %[[LOWER]]
+ // CHECK: %[[IDX:.+]] = arith.index_cast %[[CLAMPED]]
// CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[IDY]], %[[IDX]], %[[IDX_3]]]
// CHECK: linalg.yield %[[EXTRACT]]
@@ -236,29 +232,21 @@ func.func @resize_bilinear_int(%arg0: tensor<1x19x20x1xi8>) {
// Bound check each dimension.
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[ZERO]]
- // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_Y]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]]
- // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
+ // CHECK: %[[BOUND:.*]] = arith.maxsi %[[ZERO]], %[[I_Y]]
+ // CHECK: %[[YLO:.*]] = arith.minsi %[[Y_MAX]], %[[BOUND]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[ZERO]]
- // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[Y1]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]]
- // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
+ // CHECK: %[[BOUND:.*]] = arith.maxsi %[[ZERO]], %[[Y1]]
+ // CHECK: %[[YHI:.*]] = arith.minsi %[[Y_MAX]], %[[BOUND]]
// CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]]
// CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]]
// CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[ZERO]]
- // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_X]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]]
- // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
+ // CHECK: %[[BOUND:.*]] = arith.maxsi %[[ZERO]], %[[I_X]]
+ // CHECK: %[[XLO:.*]] = arith.minsi %[[X_MAX]], %[[BOUND]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[ZERO]]
- // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[X1]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]]
- // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
+ // CHECK: %[[BOUND:.*]] = arith.maxsi %[[ZERO]], %[[X1]]
+ // CHECK: %[[XHI:.*]] = arith.minsi %[[X_MAX]], %[[BOUND]]
// CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]]
// CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]]
@@ -352,21 +340,17 @@ func.func @resize_nearest_fp32(%input: tensor<1x50x48x1xf32>) -> () {
// CHECK: %[[PRED_Y:.*]] = arith.cmpf oge, %[[D_Y]], %[[HALF]]
// CHECK: %[[ROUND_Y:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]]
// CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_39]], %[[ROUND_Y]]
- // CHECK: %[[VAL_50:.*]] = arith.cmpi slt, %[[VAL_48]], %[[ZERO]]
- // CHECK: %[[VAL_51:.*]] = arith.select %[[VAL_50]], %[[ZERO]], %[[VAL_48]]
- // CHECK: %[[VAL_52:.*]] = arith.cmpi slt, %[[YMAX]], %[[VAL_48]]
- // CHECK: %[[VAL_53:.*]] = arith.select %[[VAL_52]], %[[YMAX]], %[[VAL_51]]
- // CHECK: %[[IDY:.*]] = arith.index_cast %[[VAL_53]]
+ // CHECK: %[[LOWER:.*]] = arith.maxsi %[[ZERO]], %[[VAL_48]]
+ // CHECK: %[[CLAMPED:.*]] = arith.minsi %[[YMAX]], %[[LOWER]]
+ // CHECK: %[[IDY:.*]] = arith.index_cast %[[CLAMPED]]
// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01
// CHECK: %[[PRED_X:.*]] = arith.cmpf oge, %[[D_X]], %[[HALF]]
// CHECK: %[[ROUND_X:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
// CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_40]], %[[ROUND_X]]
- // CHECK: %[[VAL_54:.*]] = arith.cmpi slt, %[[VAL_49]], %[[ZERO]]
- // CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_54]], %[[ZERO]], %[[VAL_49]]
- // CHECK: %[[VAL_56:.*]] = arith.cmpi slt, %[[XMAX]], %[[VAL_49]]
- // CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_56]], %[[XMAX]], %[[VAL_55]]
- // CHECK: %[[IDX:.*]] = arith.index_cast %[[VAL_57]]
+ // CHECK: %[[LOWER:.*]] = arith.maxsi %[[ZERO]], %[[VAL_49]]
+ // CHECK: %[[CLAMPED:.*]] = arith.minsi %[[XMAX]], %[[LOWER]]
+ // CHECK: %[[IDX:.*]] = arith.index_cast %[[CLAMPED]]
// CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[IDY]], %[[IDX]], %[[IDX3]]]
// CHECK: linalg.yield %[[EXTRACT]]
@@ -429,28 +413,21 @@ func.func @resize_bilinear_fp(%input: tensor<1x23x24x1xf32>) -> () {
// CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[ZERO]]
- // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_Y]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]]
- // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
+ // CHECK: %[[BOUND:.*]] = arith.maxsi %[[ZERO]], %[[I_Y]]
+ // CHECK: %[[YLO:.*]] = arith.minsi %[[Y_MAX]], %[[BOUND]]
+
+ // CHECK: %[[BOUND:.*]] = arith.maxsi %[[ZERO]], %[[Y1]]
+ // CHECK: %[[YHI:.*]] = arith.minsi %[[Y_MAX]], %[[BOUND]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[ZERO]]
- // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[Y1]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]]
- // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
// CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]]
// CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]]
// CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[ZERO]]
- // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_X]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]]
- // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
-
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[ZERO]]
- // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[X1]]
- // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]]
- // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
+ // CHECK: %[[BOUND:.*]] = arith.maxsi %[[ZERO]], %[[I_X]]
+ // CHECK: %[[XLO:.*]] = arith.minsi %[[X_MAX]], %[[BOUND]]
+
+ // CHECK: %[[BOUND:.*]] = arith.maxsi %[[ZERO]], %[[X1]]
+ // CHECK: %[[XHI:.*]] = arith.minsi %[[X_MAX]], %[[BOUND]]
// CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]]
// CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]]
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index fc22a436526a6f..febe74e8767465 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -684,18 +684,16 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
%16 = tosa.select %14, %0, %1 : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
- // CHECK: arith.cmpi
- // CHECK: select
+ // CHECK: arith.maxsi
%17 = tosa.maximum %0, %1 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
- // CHECK: arith.cmpi
- // CHECK: select
+ // CHECK: arith.minsi
%18 = tosa.minimum %0, %1 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
- // CHECK: arith.cmpi
- // CHECK: select
+ // CHECK-DAG: arith.maxsi
+ // CHECK-DAG: arith.minsi
%19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
@@ -717,9 +715,8 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK: linalg.generic
// CHECK: arith.constant 0
- // CHECK: arith.cmpi sgt
// CHECK: arith.subi
- // CHECK: select
+ // CHECK: arith.maxsi
%24 = tosa.abs %arg0 : (tensor<1xi32>) -> tensor<1xi32>
return
@@ -745,20 +742,16 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK-DAG: %[[C127:.+]] = arith.constant -127
// CHECK-DAG: %[[C126:.+]] = arith.constant 126
- // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %[[ARG1]], %[[C127]]
- // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C127]]
- // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C126]], %[[ARG1]]
- // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C126]], %[[SEL1]]
+ // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
+ // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
%0 = tosa.clamp %arg0 {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
// CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK-DAG: %[[C128:.+]] = arith.constant -128
// CHECK-DAG: %[[C127:.+]] = arith.constant 127
- // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %[[ARG1]], %[[C128]]
- // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C128]]
- // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C127]], %[[ARG1]]
- // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C127]], %[[SEL1]]
+ // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C128]], %[[ARG1]]
+ // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C127]], %[[LOWER]]
%1 = tosa.clamp %arg0 {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
return
@@ -814,10 +807,8 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], [[EXT]]
// CHECK: [[MIN:%.+]] = arith.constant -128
// CHECK: [[MAX:%.+]] = arith.constant 127
- // CHECK: [[PRED1:%.+]] = arith.cmpi slt, [[SUB]], [[MIN]]
- // CHECK: [[LBOUND:%.+]] = arith.select [[PRED1]], [[MIN]], [[SUB]]
- // CHECK: [[PRED2:%.+]] = arith.cmpi slt, [[MAX]], [[SUB]]
- // CHECK: [[UBOUND:%.+]] = arith.select [[PRED2]], [[MAX]], [[LBOUND]]
+ // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
+ // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
// CHECK: linalg.yield [[TRUNC]]
%0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
@@ -1009,15 +1000,13 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
// CHECK: arith.constant 2147483647 : i32
// CHECK: linalg.fill
// CHECK: linalg.reduce
- // CHECK: arith.cmpi slt
- // CHECK: select
+ // CHECK: arith.minsi
%3 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32>
// CHECK: arith.constant -2147483648 : i32
// CHECK: linalg.fill
// CHECK: linalg.reduce
- // CHECK: arith.cmpi sgt
- // CHECK: select
+ // CHECK: arith.maxsi
%4 = tosa.reduce_max %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32>
return
}
@@ -1066,10 +1055,8 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
- // CHECK-DAG: [[MINLT:%.+]] = arith.cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
- // CHECK-DAG: [[MAXLT:%.+]] = arith.cmpi slt, [[CMAX]], [[SCALED_ZEROED]]
- // CHECK-DAG: [[LOWER:%.+]] = arith.select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]]
- // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]]
+ // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+ // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK-DAG: linalg.yield [[TRUNC]]
%0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
@@ -1087,10 +1074,8 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
- // CHECK-DAG: [[MINLT:%.+]] = arith.cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
- // CHECK-DAG: [[LOWER:%.+]] = arith.select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]]
- // CHECK-DAG: [[MAXLT:%.+]] = arith.cmpi slt, [[CMAX]], [[SCALED_ZEROED]]
- // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]]
+ // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+ // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
// CHECK: linalg.yield [[CAST]]
@@ -1160,10 +1145,8 @@ func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
- // CHECK-DAG: [[MINLT:%.+]] = arith.cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
- // CHECK-DAG: [[LOWER:%.+]] = arith.select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]]
- // CHECK-DAG: [[MAXLT:%.+]] = arith.cmpi slt, [[CMAX]], [[SCALED_ZEROED]]
- // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]]
+ // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+ // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK: linalg.yield [[TRUNC]]
%0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xui8>) -> tensor<2xi8>
@@ -1192,10 +1175,8 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C252]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
- // CHECK-DAG: [[MINLT:%.+]] = arith.cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
- // CHECK-DAG: [[MAXLT:%.+]] = arith.cmpi slt, [[CMAX]], [[SCALED_ZEROED]]
- // CHECK-DAG: [[LOWER:%.+]] = arith.select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]]
- // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]]
+ // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+ // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK-DAG: linalg.yield [[TRUNC]]
%0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, double_round = false, per_channel = false} : (tensor<3xi8>) -> tensor<3xi8>
diff --git a/mlir/test/Transforms/parametric-tiling.mlir b/mlir/test/Transforms/parametric-tiling.mlir
index e3be41e702ec45..f6cace5397def0 100644
--- a/mlir/test/Transforms/parametric-tiling.mlir
+++ b/mlir/test/Transforms/parametric-tiling.mlir
@@ -40,12 +40,10 @@ func.func @rectangular(%arg0: memref<?x?xf32>) {
scf.for %i = %c2 to %c44 step %c1 {
// Upper bound for the inner loop min(%i + %step, %c44).
// COMMON: %[[stepped:.*]] = arith.addi %[[i]], %[[step]]
- // COMMON-NEXT: arith.cmpi slt, %c44, %[[stepped]]
- // COMMON-NEXT: %[[ub:.*]] = arith.select {{.*}}, %c44, %[[stepped]]
+ // COMMON-NEXT: %[[ub:.*]] = arith.minsi %c44, %[[stepped]]
//
// TILE_74: %[[stepped2:.*]] = arith.addi %[[j]], %[[step2]]
- // TILE_74-NEXT: arith.cmpi slt, %c44, %[[stepped2]]
- // TILE_74-NEXT: %[[ub2:.*]] = arith.select {{.*}}, %c44, %[[stepped2]]
+ // TILE_74-NEXT: %[[ub2:.*]] = arith.minsi %c44, %[[stepped2]]
// Created inner scf.
// COMMON:scf.for %[[ii:.*]] = %[[i]] to %[[ub:.*]] step %c1
@@ -108,11 +106,9 @@ func.func @triangular(%arg0: memref<?x?xf32>) {
scf.for %i = %c2 to %c44 step %c1 {
// Upper bound for the inner loop min(%i + %step, %c44).
// COMMON: %[[stepped:.*]] = arith.addi %[[i]], %[[step]]
- // COMMON-NEXT: arith.cmpi slt, %c44, %[[stepped]]
- // COMMON-NEXT: %[[ub:.*]] = arith.select {{.*}}, %c44, %[[stepped]]
+ // COMMON-NEXT: %[[ub:.*]] = arith.minsi %c44, %[[stepped]]
// TILE_74: %[[stepped2:.*]] = arith.addi %[[j]], %[[step2]]
- // TILE_74-NEXT: arith.cmpi slt, %[[i]], %[[stepped2]]
- // TILE_74-NEXT: %[[ub2:.*]] = arith.select {{.*}}, %[[i]], %[[stepped2]]
+ // TILE_74-NEXT: %[[ub2:.*]] = arith.minsi %[[i]], %[[stepped2]]
//
// Created inner scf.
// COMMON:scf.for %[[ii:.*]] = %[[i]] to %[[ub:.*]] step %c1
>From 6012fd8190ae8663269bc05393e0ddc9bd171fb8 Mon Sep 17 00:00:00 2001
From: Michael Levesque-Dion <mlevesquedion at google.com>
Date: Wed, 21 Feb 2024 12:03:46 -0800
Subject: [PATCH 2/2] Drop trivial braces around if/else
---
mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 98cdfc63252711..e69f9c837ca1d6 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -49,11 +49,10 @@ static Value buildMinMaxReductionSeq(Location loc,
auto valueIt = values.begin();
Value value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) {
- if (predicate == arith::CmpIPredicate::sgt) {
+ if (predicate == arith::CmpIPredicate::sgt)
value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
- } else {
+ else
value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
- }
}
return value;
More information about the Mlir-commits mailing list