[Mlir-commits] [mlir] [mlir] Use arith max or min ops instead of cmp + select (PR #82178)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Feb 18 09:11:13 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: mlevesquedion (mlevesquedion)
<details>
<summary>Changes</summary>
I believe the semantics are unchanged, and each rewrite saves one op and simplifies the code.
For example, the following two instructions:
```
%2 = arith.cmpi sgt, %0, %1
%3 = arith.select %2, %0, %1
```
are equivalent to:
```
%2 = arith.maxsi %0, %1
```
---
Patch is 44.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82178.diff
13 Files Affected:
- (modified) mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp (+8-9)
- (modified) mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp (+2-6)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-15)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+2-7)
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+2-4)
- (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+2-7)
- (modified) mlir/test/Conversion/AffineToStandard/lower-affine.mlir (+16-26)
- (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+1-2)
- (modified) mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir (+6-12)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+15-29)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir (+30-53)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+21-40)
- (modified) mlir/test/Transforms/parametric-tiling.mlir (+4-8)
``````````diff
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 15ad6d8cdf629d..98cdfc63252711 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -34,12 +34,7 @@ using namespace mlir::affine;
using namespace mlir::vector;
/// Given a range of values, emit the code that reduces them with "min" or "max"
-/// depending on the provided comparison predicate. The predicate defines which
-/// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
-/// `cmpi` operation followed by the `select` operation:
-///
-/// %cond = arith.cmpi "predicate" %v0, %v1
-/// %result = select %cond, %v0, %v1
+/// depending on the provided comparison predicate, sgt for max and slt for min.
///
/// Multiple values are scanned in a linear sequence. This creates a data
/// dependences that wouldn't exist in a tree reduction, but is easier to
@@ -48,13 +43,17 @@ static Value buildMinMaxReductionSeq(Location loc,
arith::CmpIPredicate predicate,
ValueRange values, OpBuilder &builder) {
assert(!values.empty() && "empty min/max chain");
+ assert(predicate == arith::CmpIPredicate::sgt ||
+ predicate == arith::CmpIPredicate::slt);
auto valueIt = values.begin();
Value value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) {
- auto cmpOp = builder.create<arith::CmpIOp>(loc, predicate, value, *valueIt);
- value = builder.create<arith::SelectOp>(loc, cmpOp.getResult(), value,
- *valueIt);
+ if (predicate == arith::CmpIPredicate::sgt) {
+ value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
+ } else {
+ value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
+ }
}
return value;
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index a3e51aeed0735a..de649f730ee9d7 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -147,9 +147,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
- Value rankIsGreater =
- lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
- maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
+ maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
}
// Calculate the difference of ranks and the maximum rank for later offsets.
@@ -262,9 +260,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
// Find the maximum rank
Value maxRank = ranks.front();
for (Value v : llvm::drop_begin(ranks, 1)) {
- Value rankIsGreater =
- lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
- maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
+ maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
}
// Calculate the difference of ranks and the maximum rank for later offsets.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f4f6dadfb37166..7eb32ebe3228fb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -61,10 +61,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementTy));
- auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
- args[0], zero);
auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
- return rewriter.create<arith::SelectOp>(loc, cmp, args[0], neg);
+ return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
}
// tosa::AddOp
@@ -348,9 +346,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
- auto predicate = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
}
// tosa::MinimumOp
@@ -359,9 +355,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
- auto predicate = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
}
// tosa::CeilOp
@@ -1000,9 +994,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
}
if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
- auto predicate = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
}
if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
@@ -1010,9 +1002,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
}
if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
- auto predicate = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
}
if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 607a603cca810f..3f39cbf03a9a80 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -845,10 +845,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);
- Value cmp = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, dpos, zero);
- Value offset =
- rewriter.create<arith::SelectOp>(loc, cmp, dpos, zero);
+ Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
return rewriter.create<arith::AddIOp>(loc, valid, offset)
->getResult(0);
};
@@ -868,9 +865,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// Determine how much padding was included.
val = padFn(val, left, pad[i * 2]);
val = padFn(val, right, pad[i * 2 + 1]);
- Value cmp = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, val, one);
- return rewriter.create<arith::SelectOp>(loc, cmp, one, val);
+ return rewriter.create<arith::MaxSIOp>(loc, one, val);
};
// Compute the indices from either end.
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 536c02feca1bd5..502d7e197a6f6b 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -791,10 +791,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor,
// Insert newForOp before the terminator of `t`.
auto b = OpBuilder::atBlockTerminator((t.getBody()));
Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
- Value less = b.create<arith::CmpIOp>(t.getLoc(), arith::CmpIPredicate::slt,
- forOp.getUpperBound(), stepped);
- Value ub = b.create<arith::SelectOp>(t.getLoc(), less,
- forOp.getUpperBound(), stepped);
+ Value ub =
+ b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);
// Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index ee428b201d0073..4fc97115064f33 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -39,13 +39,8 @@ Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
OpBuilder &rewriter) {
- auto smallerThanMin =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min);
- auto minOrArg =
- rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
- auto largerThanMax =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
- return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
+ auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
+ return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
}
bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 92608135d24b08..00d7b6b8d65f67 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -371,16 +371,14 @@ func.func @if_for() {
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[c42]] step %[[c1]] {
// CHECK-NEXT: %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT: %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
-// CHECK-NEXT: %[[b:.*]] = arith.addi %[[a]], %{{.*}} : index
-// CHECK-NEXT: %[[c:.*]] = arith.cmpi sgt, %{{.*}}, %[[b]] : index
-// CHECK-NEXT: %[[d:.*]] = arith.select %[[c]], %{{.*}}, %[[b]] : index
+// CHECK-NEXT: %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
+// CHECK-NEXT: %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} : index
+// CHECK-NEXT: %[[max:.*]] = arith.maxsi %{{.*}}, %[[add0]] : index
// CHECK-NEXT: %[[c10:.*]] = arith.constant 10 : index
-// CHECK-NEXT: %[[e:.*]] = arith.addi %{{.*}}, %[[c10]] : index
-// CHECK-NEXT: %[[f:.*]] = arith.cmpi slt, %{{.*}}, %[[e]] : index
-// CHECK-NEXT: %[[g:.*]] = arith.select %[[f]], %{{.*}}, %[[e]] : index
+// CHECK-NEXT: %[[add1:.*]] = arith.addi %{{.*}}, %[[c10]] : index
+// CHECK-NEXT: %[[min:.*]] = arith.minsi %{{.*}}, %[[add1]] : index
// CHECK-NEXT: %[[c1_0:.*]] = arith.constant 1 : index
-// CHECK-NEXT: for %{{.*}} = %[[d]] to %[[g]] step %[[c1_0]] {
+// CHECK-NEXT: for %{{.*}} = %[[max]] to %[[min]] step %[[c1_0]] {
// CHECK-NEXT: call @body2(%{{.*}}, %{{.*}}) : (index, index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
@@ -397,25 +395,19 @@ func.func @loop_min_max(%N : index) {
#map_7_values = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
-// Check that the "min" (cmpi slt + select) reduction sequence is emitted
+// Check that the "min" reduction sequence is emitted
// correctly for an affine map with 7 results.
// CHECK-LABEL: func @min_reduction_tree
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[c01:.+]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
-// CHECK-NEXT: %[[r01:.+]] = arith.select %[[c01]], %{{.*}}, %{{.*}} : index
-// CHECK-NEXT: %[[c012:.+]] = arith.cmpi slt, %[[r01]], %{{.*}} : index
-// CHECK-NEXT: %[[r012:.+]] = arith.select %[[c012]], %[[r01]], %{{.*}} : index
-// CHECK-NEXT: %[[c0123:.+]] = arith.cmpi slt, %[[r012]], %{{.*}} : index
-// CHECK-NEXT: %[[r0123:.+]] = arith.select %[[c0123]], %[[r012]], %{{.*}} : index
-// CHECK-NEXT: %[[c01234:.+]] = arith.cmpi slt, %[[r0123]], %{{.*}} : index
-// CHECK-NEXT: %[[r01234:.+]] = arith.select %[[c01234]], %[[r0123]], %{{.*}} : index
-// CHECK-NEXT: %[[c012345:.+]] = arith.cmpi slt, %[[r01234]], %{{.*}} : index
-// CHECK-NEXT: %[[r012345:.+]] = arith.select %[[c012345]], %[[r01234]], %{{.*}} : index
-// CHECK-NEXT: %[[c0123456:.+]] = arith.cmpi slt, %[[r012345]], %{{.*}} : index
-// CHECK-NEXT: %[[r0123456:.+]] = arith.select %[[c0123456]], %[[r012345]], %{{.*}} : index
+// CHECK-NEXT: %[[min:.+]] = arith.minsi %{{.*}}, %{{.*}} : index
+// CHECK-NEXT: %[[min_0:.+]] = arith.minsi %[[min]], %{{.*}} : index
+// CHECK-NEXT: %[[min_1:.+]] = arith.minsi %[[min_0]], %{{.*}} : index
+// CHECK-NEXT: %[[min_2:.+]] = arith.minsi %[[min_1]], %{{.*}} : index
+// CHECK-NEXT: %[[min_3:.+]] = arith.minsi %[[min_2]], %{{.*}} : index
+// CHECK-NEXT: %[[min_4:.+]] = arith.minsi %[[min_3]], %{{.*}} : index
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[r0123456]] step %[[c1]] {
+// CHECK-NEXT: for %{{.*}} = %[[c0]] to %[[min_4]] step %[[c1]] {
// CHECK-NEXT: call @body(%{{.*}}) : (index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: return
@@ -690,8 +682,7 @@ func.func @affine_min(%arg0: index, %arg1: index) -> index{
// CHECK: %[[Cm2:.*]] = arith.constant -1
// CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
// CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
- // CHECK: %[[cmp:.*]] = arith.cmpi slt, %[[first]], %[[second]]
- // CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
+ // CHECK: arith.minsi %[[first]], %[[second]]
%0 = affine.min affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
return %0 : index
}
@@ -705,8 +696,7 @@ func.func @affine_max(%arg0: index, %arg1: index) -> index{
// CHECK: %[[Cm2:.*]] = arith.constant -1
// CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
// CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
- // CHECK: %[[cmp:.*]] = arith.cmpi sgt, %[[first]], %[[second]]
- // CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
+ // CHECK: arith.maxsi %[[first]], %[[second]]
%0 = affine.max affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
return %0 : index
}
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index eb45112b117c0d..87d613986c7c3f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -554,8 +554,7 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK: %[[IS_MIN_STRIDE1:.*]] = llvm.icmp "slt" %[[STRIDE1]], %[[C1]] : i64
-// CHECK: %[[MIN_STRIDE1:.*]] = llvm.select %[[IS_MIN_STRIDE1]], %[[STRIDE1]], %[[C1]] : i1, i64
+// CHECK: %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
// CHECK: %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
// CHECK: %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index cb3af973daee20..3b73c513b7955f 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -377,10 +377,8 @@ func.func @try_is_broadcastable (%a : tensor<2xindex>, %b : tensor<3xindex>, %c
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
@@ -467,10 +465,8 @@ func.func @broadcast(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xi
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
@@ -559,10 +555,8 @@ func.func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
// CHECK: %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
// CHECK: %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK: %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK: %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK: %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK: %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
// CHECK: %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
// CHECK: %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
// CHECK: %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 51ebcad0797807..e64903671e599f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -263,16 +263,13 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
// CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
// CHECK: %[[PAD_START:.+]] = arith.constant 1
// CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
- // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
- // CHECK: %...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/82178
More information about the Mlir-commits mailing list