[Mlir-commits] [mlir] [mlir] Use arith max or min ops instead of cmp + select (PR #82178)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Feb 18 09:11:13 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: mlevesquedion (mlevesquedion)

<details>
<summary>Changes</summary>

Replacing each `cmp` + `select` pair with a single arith max or min op saves one op and simplifies the code; I believe the semantics are unchanged.

For example, the following two instructions:

```
%2 = cmp sgt %0, %1
%3 = select %2, %0, %1
```

are equivalent to:

```
%2 = maxsi %0, %1
```

---

Patch is 44.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82178.diff


13 Files Affected:

- (modified) mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp (+8-9) 
- (modified) mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp (+2-6) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-15) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+2-7) 
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+2-4) 
- (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+2-7) 
- (modified) mlir/test/Conversion/AffineToStandard/lower-affine.mlir (+16-26) 
- (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+1-2) 
- (modified) mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir (+6-12) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+15-29) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir (+30-53) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+21-40) 
- (modified) mlir/test/Transforms/parametric-tiling.mlir (+4-8) 


``````````diff
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 15ad6d8cdf629d..98cdfc63252711 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -34,12 +34,7 @@ using namespace mlir::affine;
 using namespace mlir::vector;
 
 /// Given a range of values, emit the code that reduces them with "min" or "max"
-/// depending on the provided comparison predicate.  The predicate defines which
-/// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
-/// `cmpi` operation followed by the `select` operation:
-///
-///   %cond   = arith.cmpi "predicate" %v0, %v1
-///   %result = select %cond, %v0, %v1
+/// depending on the provided comparison predicate, sgt for max and slt for min.
 ///
 /// Multiple values are scanned in a linear sequence.  This creates a data
 /// dependences that wouldn't exist in a tree reduction, but is easier to
@@ -48,13 +43,17 @@ static Value buildMinMaxReductionSeq(Location loc,
                                      arith::CmpIPredicate predicate,
                                      ValueRange values, OpBuilder &builder) {
   assert(!values.empty() && "empty min/max chain");
+  assert(predicate == arith::CmpIPredicate::sgt ||
+         predicate == arith::CmpIPredicate::slt);
 
   auto valueIt = values.begin();
   Value value = *valueIt++;
   for (; valueIt != values.end(); ++valueIt) {
-    auto cmpOp = builder.create<arith::CmpIOp>(loc, predicate, value, *valueIt);
-    value = builder.create<arith::SelectOp>(loc, cmpOp.getResult(), value,
-                                            *valueIt);
+    if (predicate == arith::CmpIPredicate::sgt) {
+      value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
+    } else {
+      value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
+    }
   }
 
   return value;
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index a3e51aeed0735a..de649f730ee9d7 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -147,9 +147,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
   // Find the maximum rank
   Value maxRank = ranks.front();
   for (Value v : llvm::drop_begin(ranks, 1)) {
-    Value rankIsGreater =
-        lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
-    maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
+    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
   }
 
   // Calculate the difference of ranks and the maximum rank for later offsets.
@@ -262,9 +260,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
   // Find the maximum rank
   Value maxRank = ranks.front();
   for (Value v : llvm::drop_begin(ranks, 1)) {
-    Value rankIsGreater =
-        lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
-    maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
+    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
   }
 
   // Calculate the difference of ranks and the maximum rank for later offsets.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f4f6dadfb37166..7eb32ebe3228fb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -61,10 +61,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
     auto zero = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementTy));
-    auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
-                                              args[0], zero);
     auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
-    return rewriter.create<arith::SelectOp>(loc, cmp, args[0], neg);
+    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
   }
 
   // tosa::AddOp
@@ -348,9 +346,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   }
 
   if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
-    auto predicate = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
-    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
   }
 
   // tosa::MinimumOp
@@ -359,9 +355,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   }
 
   if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
-    auto predicate = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, args[0], args[1]);
-    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
   }
 
   // tosa::CeilOp
@@ -1000,9 +994,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
   }
 
   if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
-    auto predicate = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, args[0], args[1]);
-    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
   }
 
   if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
@@ -1010,9 +1002,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
   }
 
   if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
-    auto predicate = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
-    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
   }
 
   if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 607a603cca810f..3f39cbf03a9a80 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -845,10 +845,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
             auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
             Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);
 
-            Value cmp = rewriter.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::slt, dpos, zero);
-            Value offset =
-                rewriter.create<arith::SelectOp>(loc, cmp, dpos, zero);
+            Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
             return rewriter.create<arith::AddIOp>(loc, valid, offset)
                 ->getResult(0);
           };
@@ -868,9 +865,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
             // Determine how much padding was included.
             val = padFn(val, left, pad[i * 2]);
             val = padFn(val, right, pad[i * 2 + 1]);
-            Value cmp = rewriter.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::slt, val, one);
-            return rewriter.create<arith::SelectOp>(loc, cmp, one, val);
+            return rewriter.create<arith::MaxSIOp>(loc, one, val);
           };
 
           // Compute the indices from either end.
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 536c02feca1bd5..502d7e197a6f6b 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -791,10 +791,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor,
     // Insert newForOp before the terminator of `t`.
     auto b = OpBuilder::atBlockTerminator((t.getBody()));
     Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
-    Value less = b.create<arith::CmpIOp>(t.getLoc(), arith::CmpIPredicate::slt,
-                                         forOp.getUpperBound(), stepped);
-    Value ub = b.create<arith::SelectOp>(t.getLoc(), less,
-                                         forOp.getUpperBound(), stepped);
+    Value ub =
+        b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);
 
     // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
     auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index ee428b201d0073..4fc97115064f33 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -39,13 +39,8 @@ Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
 
 Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
                                  OpBuilder &rewriter) {
-  auto smallerThanMin =
-      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min);
-  auto minOrArg =
-      rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
-  auto largerThanMax =
-      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
-  return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
+  auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
+  return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
 }
 
 bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 92608135d24b08..00d7b6b8d65f67 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -371,16 +371,14 @@ func.func @if_for() {
 // CHECK-NEXT:   %[[c1:.*]] = arith.constant 1 : index
 // CHECK-NEXT:   for %{{.*}} = %[[c0]] to %[[c42]] step %[[c1]] {
 // CHECK-NEXT:     %[[cm1:.*]] = arith.constant -1 : index
-// CHECK-NEXT:     %[[a:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
-// CHECK-NEXT:     %[[b:.*]] = arith.addi %[[a]], %{{.*}} : index
-// CHECK-NEXT:     %[[c:.*]] = arith.cmpi sgt, %{{.*}}, %[[b]] : index
-// CHECK-NEXT:     %[[d:.*]] = arith.select %[[c]], %{{.*}}, %[[b]] : index
+// CHECK-NEXT:     %[[mul0:.*]] = arith.muli %{{.*}}, %[[cm1]] : index
+// CHECK-NEXT:     %[[add0:.*]] = arith.addi %[[mul0]], %{{.*}} : index
+// CHECK-NEXT:     %[[max:.*]] = arith.maxsi %{{.*}}, %[[add0]] : index
 // CHECK-NEXT:     %[[c10:.*]] = arith.constant 10 : index
-// CHECK-NEXT:     %[[e:.*]] = arith.addi %{{.*}}, %[[c10]] : index
-// CHECK-NEXT:     %[[f:.*]] = arith.cmpi slt, %{{.*}}, %[[e]] : index
-// CHECK-NEXT:     %[[g:.*]] = arith.select %[[f]], %{{.*}}, %[[e]] : index
+// CHECK-NEXT:     %[[add1:.*]] = arith.addi %{{.*}}, %[[c10]] : index
+// CHECK-NEXT:     %[[min:.*]] = arith.minsi %{{.*}}, %[[add1]] : index
 // CHECK-NEXT:     %[[c1_0:.*]] = arith.constant 1 : index
-// CHECK-NEXT:     for %{{.*}} = %[[d]] to %[[g]] step %[[c1_0]] {
+// CHECK-NEXT:     for %{{.*}} = %[[max]] to %[[min]] step %[[c1_0]] {
 // CHECK-NEXT:       call @body2(%{{.*}}, %{{.*}}) : (index, index) -> ()
 // CHECK-NEXT:     }
 // CHECK-NEXT:   }
@@ -397,25 +395,19 @@ func.func @loop_min_max(%N : index) {
 
 #map_7_values = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
 
-// Check that the "min" (cmpi slt + select) reduction sequence is emitted
+// Check that the "min" reduction sequence is emitted
 // correctly for an affine map with 7 results.
 
 // CHECK-LABEL: func @min_reduction_tree
 // CHECK-NEXT:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-NEXT:   %[[c01:.+]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
-// CHECK-NEXT:   %[[r01:.+]] = arith.select %[[c01]], %{{.*}}, %{{.*}} : index
-// CHECK-NEXT:   %[[c012:.+]] = arith.cmpi slt, %[[r01]], %{{.*}} : index
-// CHECK-NEXT:   %[[r012:.+]] = arith.select %[[c012]], %[[r01]], %{{.*}} : index
-// CHECK-NEXT:   %[[c0123:.+]] = arith.cmpi slt, %[[r012]], %{{.*}} : index
-// CHECK-NEXT:   %[[r0123:.+]] = arith.select %[[c0123]], %[[r012]], %{{.*}} : index
-// CHECK-NEXT:   %[[c01234:.+]] = arith.cmpi slt, %[[r0123]], %{{.*}} : index
-// CHECK-NEXT:   %[[r01234:.+]] = arith.select %[[c01234]], %[[r0123]], %{{.*}} : index
-// CHECK-NEXT:   %[[c012345:.+]] = arith.cmpi slt, %[[r01234]], %{{.*}} : index
-// CHECK-NEXT:   %[[r012345:.+]] = arith.select %[[c012345]], %[[r01234]], %{{.*}} : index
-// CHECK-NEXT:   %[[c0123456:.+]] = arith.cmpi slt, %[[r012345]], %{{.*}} : index
-// CHECK-NEXT:   %[[r0123456:.+]] = arith.select %[[c0123456]], %[[r012345]], %{{.*}} : index
+// CHECK-NEXT:   %[[min:.+]] = arith.minsi %{{.*}}, %{{.*}} : index
+// CHECK-NEXT:   %[[min_0:.+]] = arith.minsi %[[min]], %{{.*}} : index
+// CHECK-NEXT:   %[[min_1:.+]] = arith.minsi %[[min_0]], %{{.*}} : index
+// CHECK-NEXT:   %[[min_2:.+]] = arith.minsi %[[min_1]], %{{.*}} : index
+// CHECK-NEXT:   %[[min_3:.+]] = arith.minsi %[[min_2]], %{{.*}} : index
+// CHECK-NEXT:   %[[min_4:.+]] = arith.minsi %[[min_3]], %{{.*}} : index
 // CHECK-NEXT:   %[[c1:.*]] = arith.constant 1 : index
-// CHECK-NEXT:   for %{{.*}} = %[[c0]] to %[[r0123456]] step %[[c1]] {
+// CHECK-NEXT:   for %{{.*}} = %[[c0]] to %[[min_4]] step %[[c1]] {
 // CHECK-NEXT:     call @body(%{{.*}}) : (index) -> ()
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return
@@ -690,8 +682,7 @@ func.func @affine_min(%arg0: index, %arg1: index) -> index{
   // CHECK: %[[Cm2:.*]] = arith.constant -1
   // CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
   // CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
-  // CHECK: %[[cmp:.*]] = arith.cmpi slt, %[[first]], %[[second]]
-  // CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
+  // CHECK: arith.minsi %[[first]], %[[second]]
   %0 = affine.min affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
   return %0 : index
 }
@@ -705,8 +696,7 @@ func.func @affine_max(%arg0: index, %arg1: index) -> index{
   // CHECK: %[[Cm2:.*]] = arith.constant -1
   // CHECK: %[[neg2:.*]] = arith.muli %[[ARG0]], %[[Cm2:.*]]
   // CHECK: %[[second:.*]] = arith.addi %[[ARG1]], %[[neg2]]
-  // CHECK: %[[cmp:.*]] = arith.cmpi sgt, %[[first]], %[[second]]
-  // CHECK: arith.select %[[cmp]], %[[first]], %[[second]]
+  // CHECK: arith.maxsi %[[first]], %[[second]]
   %0 = affine.max affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
   return %0 : index
 }
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index eb45112b117c0d..87d613986c7c3f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -554,8 +554,7 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
 // CHECK:           %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
 // CHECK:           %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
 // CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK:           %[[IS_MIN_STRIDE1:.*]] = llvm.icmp "slt" %[[STRIDE1]], %[[C1]] : i64
-// CHECK:           %[[MIN_STRIDE1:.*]] = llvm.select %[[IS_MIN_STRIDE1]], %[[STRIDE1]], %[[C1]] : i1, i64
+// CHECK:           %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
 // CHECK:           %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
 // CHECK:           %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
 // CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index cb3af973daee20..3b73c513b7955f 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -377,10 +377,8 @@ func.func @try_is_broadcastable (%a : tensor<2xindex>, %b : tensor<3xindex>, %c
 // CHECK:           %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
 // CHECK:           %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
 // CHECK:           %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK:           %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK:           %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
 // CHECK:           %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
 // CHECK:           %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
 // CHECK:           %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
@@ -467,10 +465,8 @@ func.func @broadcast(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xi
 // CHECK:           %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
 // CHECK:           %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
 // CHECK:           %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK:           %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK:           %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
 // CHECK:           %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
 // CHECK:           %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
 // CHECK:           %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
@@ -559,10 +555,8 @@ func.func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
 // CHECK:           %[[RANK0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<2xindex>
 // CHECK:           %[[RANK1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<3xindex>
 // CHECK:           %[[RANK2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<2xindex>
-// CHECK:           %[[CMP0:.*]] = arith.cmpi ugt, %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[LARGER_DIM:.*]] = arith.select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
-// CHECK:           %[[CMP1:.*]] = arith.cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
-// CHECK:           %[[MAX_RANK:.*]] = arith.select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX0:.*]] = arith.maxui %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[MAX_RANK:.*]] = arith.maxui %[[RANK2]], %[[MAX0]] : index
 // CHECK:           %[[DIM_DIFF0:.*]] = arith.subi %[[MAX_RANK]], %[[RANK0]] : index
 // CHECK:           %[[DIM_DIFF1:.*]] = arith.subi %[[MAX_RANK]], %[[RANK1]] : index
 // CHECK:           %[[DIM_DIFF2:.*]] = arith.subi %[[MAX_RANK]], %[[RANK2]] : index
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 51ebcad0797807..e64903671e599f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -263,16 +263,13 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
   // CHECK:   %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
   // CHECK:   %[[PAD_START:.+]] = arith.constant 1
   // CHECK:   %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
-  // CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
-  // CHECK:   %...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/82178


More information about the Mlir-commits mailing list