[Mlir-commits] [mlir] 2eb50ce - [mlir][tosa] Use arith::maxf/arith::minf in lowering from tosa

Thomas Raoux llvmlistbot at llvm.org
Mon Aug 8 18:13:35 PDT 2022


Author: Thomas Raoux
Date: 2022-08-09T01:10:32Z
New Revision: 2eb50cee11ccbfac71eeb7687b9f136d95fc7f52

URL: https://github.com/llvm/llvm-project/commit/2eb50cee11ccbfac71eeb7687b9f136d95fc7f52
DIFF: https://github.com/llvm/llvm-project/commit/2eb50cee11ccbfac71eeb7687b9f136d95fc7f52.diff

LOG: [mlir][tosa] Use arith::maxf/arith::minf in lowering from tosa

Now that the `arith` dialect has maxf/minf, use them instead of cmp/select.
Also refactor the clamp helpers to make them simpler.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D131426

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
    mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
index 3e50358835ab9..11509c8a4f6bf 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
@@ -27,17 +27,15 @@ SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops);
 // Takes a vector of values and condenses them to a vector with no gaps.
 SmallVector<Value> condenseValues(const SmallVector<Value> &values);
 
-// Takes the parameters for a clamp and turns it into a series of ops.
-template <typename T, typename P>
-arith::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min,
-                            arith::ConstantOp max, P pred,
-                            OpBuilder &rewriter) {
-  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
-  auto minOrArg =
-      rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
-  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
-  return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
-}
+// Takes the parameters for a clamp and turns it into a series of ops for float
+// inputs.
+Value clampFloatHelper(Location loc, Value arg, arith::ConstantOp min,
+                       arith::ConstantOp max, OpBuilder &rewriter);
+
+// Takes the parameters for a clamp and turns it into a series of ops for
+// integer inputs.
+Value clampIntHelper(Location loc, Value arg, arith::ConstantOp min,
+                     arith::ConstantOp max, OpBuilder &rewriter);
 
 // Returns the values in an attribute as an array of values.
 template <typename T>

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index efaf612360852..374c663511599 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -182,8 +182,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     auto max = rewriter.create<arith::ConstantIntOp>(
         loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
         intermediateType);
-    auto clamp = clampHelper<arith::CmpIOp>(
-        loc, sub, min, max, arith::CmpIPredicate::slt, rewriter);
+    auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
 
     // Truncate to the final value.
     return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
@@ -335,9 +334,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
 
   // tosa::MaximumOp
   if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
-    auto predicate = rewriter.create<arith::CmpFOp>(
-        loc, arith::CmpFPredicate::OGT, args[0], args[1]);
-    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+    return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
   }
 
   if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -348,9 +345,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
 
   // tosa::MinimumOp
   if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
-    auto predicate = rewriter.create<arith::CmpFOp>(
-        loc, arith::CmpFPredicate::OLT, args[0], args[1]);
-    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+    return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
   }
 
   if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
@@ -380,8 +375,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
         loc, elementTy, rewriter.getFloatAttr(elementTy, min_apf));
     auto max = rewriter.create<arith::ConstantOp>(
         loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
-    return clampHelper<arith::CmpFOp>(loc, args[0], min, max,
-                                      arith::CmpFPredicate::OLT, rewriter);
+    return clampFloatHelper(loc, args[0], min, max, rewriter);
   }
 
   if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
@@ -409,8 +403,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
         loc, min, intTy.getIntOrFloatBitWidth());
     auto maxVal = rewriter.create<arith::ConstantIntOp>(
         loc, max, intTy.getIntOrFloatBitWidth());
-    return clampHelper<arith::CmpIOp>(loc, args[0], minVal, maxVal,
-                                      arith::CmpIPredicate::slt, rewriter);
+    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
   }
 
   // tosa::ReluNOp
@@ -423,8 +416,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                     APFloat::rmNearestTiesToEven, &losesInfo);
     auto n = rewriter.create<arith::ConstantOp>(
         loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
-    return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
-                                      arith::CmpFPredicate::OLT, rewriter);
+    return clampFloatHelper(loc, args[0], zero, n, rewriter);
   }
 
   if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
@@ -432,8 +424,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
     auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                   rewriter);
-    return clampHelper<arith::CmpIOp>(loc, args[0], zero, n,
-                                      arith::CmpIPredicate::slt, rewriter);
+    return clampIntHelper(loc, args[0], zero, n, rewriter);
   }
 
   // tosa::SigmoidOp
@@ -521,8 +512,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
       auto rounded =
           rewriter.create<arith::SelectOp>(loc, negative, subbed, added);
 
-      auto clamped = clampHelper<arith::CmpFOp>(
-          loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter);
+      auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
 
       return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
     }
@@ -553,8 +543,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
               .getSExtValue(),
           srcTy.getIntOrFloatBitWidth());
 
-      auto clamped = clampHelper<arith::CmpIOp>(
-          loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter);
+      auto clamped = clampIntHelper(loc, args[0], intMin, intMax, rewriter);
       return rewriter.create<arith::TruncIOp>(loc, dstTy, clamped);
     }
   }
@@ -751,9 +740,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
   }
 
   if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
-    auto predicate = rewriter.create<arith::CmpFOp>(
-        loc, arith::CmpFPredicate::OLT, args[0], args[1]);
-    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+    return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
   }
 
   if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
@@ -763,9 +750,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
   }
 
   if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
-    auto predicate = rewriter.create<arith::CmpFOp>(
-        loc, arith::CmpFPredicate::OGT, args[0], args[1]);
-    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+    return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
   }
 
   if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
@@ -1314,9 +1299,8 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
               loc, nestedBuilder.getI32IntegerAttr(intMax));
 
-          value = clampHelper<arith::CmpIOp>(
-              nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt,
-              nestedBuilder);
+          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
+                                 nestedBuilder);
 
           if (outIntType.getWidth() < 32) {
             value = nestedBuilder.create<arith::TruncIOp>(
@@ -1497,10 +1481,8 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
 
       // Clamp the to be within the bounds of the input image.
 
-      iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax,
-                                      arith::CmpIPredicate::slt, rewriter);
-      ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax,
-                                      arith::CmpIPredicate::slt, rewriter);
+      iy = clampIntHelper(loc, iy, hwMin, hMax, rewriter);
+      ix = clampIntHelper(loc, ix, hwMin, wMax, rewriter);
 
       // Read the value from the input array.
       iy =
@@ -1525,15 +1507,11 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
       Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
       Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);
 
-      y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax,
-                                      arith::CmpIPredicate::slt, rewriter);
-      y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax,
-                                      arith::CmpIPredicate::slt, rewriter);
+      y0 = clampIntHelper(loc, y0, hwMin, hMax, rewriter);
+      y1 = clampIntHelper(loc, y1, hwMin, hMax, rewriter);
 
-      x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax,
-                                      arith::CmpIPredicate::slt, rewriter);
-      x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax,
-                                      arith::CmpIPredicate::slt, rewriter);
+      x0 = clampIntHelper(loc, x0, hwMin, wMax, rewriter);
+      x1 = clampIntHelper(loc, x1, hwMin, wMax, rewriter);
 
       y0 =
           rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y0);

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 5e491f2ef437c..42bca1ef8ff24 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -943,8 +943,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
             auto max = rewriter.create<arith::ConstantIntOp>(
                 loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                 accETy);
-            auto clamp = clampHelper<arith::CmpIOp>(
-                loc, scaled, min, max, arith::CmpIPredicate::slt, rewriter);
+            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);
 
             poolVal = clamp;
             // Convert type.

diff  --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index e994adb29bf5c..33999f3ad36ce 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -28,3 +28,25 @@ mlir::tosa::condenseValues(const SmallVector<Value> &values) {
       condensedValues.push_back(value);
   return condensedValues;
 }
+
+// Clamps a floating-point `arg` into [min, max]: minf against `max` caps the
+// value from above, then maxf against `min` raises it from below.
+Value mlir::tosa::clampFloatHelper(Location loc, Value arg,
+                                   arith::ConstantOp min, arith::ConstantOp max,
+                                   OpBuilder &rewriter) {
+  Value upperClamped = rewriter.create<arith::MinFOp>(loc, arg, max);
+  return rewriter.create<arith::MaxFOp>(loc, upperClamped, min);
+}
+
+// Clamps an integer `arg` into [min, max] with signed (slt) compares and
+// selects.
+Value mlir::tosa::clampIntHelper(Location loc, Value arg, arith::ConstantOp min,
+                                 arith::ConstantOp max, OpBuilder &rewriter) {
+  auto smallerThanMin =
+      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min);
+  auto minOrArg =
+      rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
+  auto largerThanMax =
+      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
+  return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
+}

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index cd405cdd03b04..47efb8e72cb1c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -198,13 +198,11 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   %13 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: arith.cmpf
-  // CHECK: select
+  // CHECK: arith.maxf
   %14 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: arith.cmpf
-  // CHECK: select
+  // CHECK: arith.minf
   %15 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
@@ -216,13 +214,13 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   %17 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: arith.cmpf
-  // CHECK: select
+  // CHECK: arith.minf
+  // CHECK: arith.maxf
   %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: arith.cmpf
-  // CHECK: select
+  // CHECK: arith.minf
+  // CHECK: arith.maxf
   %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
@@ -241,10 +239,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: arith.subf
   // CHECK: arith.cmpf olt
   // CHECK: select
-  // CHECK: arith.cmpf olt
-  // CHECK: select
-  // CHECK: arith.cmpf olt
-  // CHECK: select
+  // CHECK: arith.minf
+  // CHECK: arith.maxf
   // CHECK: arith.fptosi
   %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
 
@@ -451,20 +447,22 @@ func.func @test_simple_ui8(%arg0: tensor<1xi8>) -> () {
 // CHECK-LABEL: @test_i8
 func.func @test_i8(%arg0: tensor<1xi8>) -> () {
   // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[ARG1:.+]]: i8,
   // CHECK-DAG: %[[C127:.+]] = arith.constant -127
   // CHECK-DAG: %[[C126:.+]] = arith.constant 126
-  // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %arg1, %[[C127]]
+  // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %[[ARG1]], %[[C127]]
   // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C127]]
-  // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C126]], %arg1
+  // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C126]], %[[ARG1]]
   // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C126]], %[[SEL1]]
   %0 = "tosa.clamp"(%arg0) {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[ARG1:.+]]: i8,
   // CHECK-DAG: %[[C128:.+]] = arith.constant -128
   // CHECK-DAG: %[[C127:.+]] = arith.constant 127
-  // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %arg1, %[[C128]]
+  // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %[[ARG1]], %[[C128]]
   // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C128]]
-  // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C127]], %arg1
+  // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C127]], %[[ARG1]]
   // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C127]], %[[SEL1]]
   %1 = "tosa.clamp"(%arg0) {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
 
@@ -476,12 +474,11 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
 // CHECK-LABEL: @test_clamp_f16
 func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[ARG1:.+]]: f16, 
   // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
   // CHECK-DAG: %[[C6:.+]] = arith.constant 6.0
-  // CHECK-DAG: %[[CMP1:.+]] = arith.cmpf olt, %arg1, %[[C0]]
-  // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C0]]
-  // CHECK-DAG: %[[CMP2:.+]] = arith.cmpf olt, %[[C6]], %arg1
-  // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C6]], %[[SEL1]]
+  // CHECK-DAG: %[[MIN:.+]] = arith.minf %[[ARG1]], %[[C6]]
+  // CHECK-DAG: %[[MAX:.+]] = arith.maxf %[[MIN]], %[[C0]]
   %0 = "tosa.clamp"(%arg0) {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16>
 
   return
@@ -732,15 +729,13 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK: arith.constant 3.40282347E+38 : f32
   // CHECK: linalg.fill
   // CHECK: linalg.generic
-  // CHECK: arith.cmpf olt
-  // CHECK: select
+  // CHECK: arith.minf
   %3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
 
   // CHECK: arith.constant -3.40282347E+38 : f32
   // CHECK: linalg.fill
   // CHECK: linalg.generic
-  // CHECK: arith.cmpf ogt
-  // CHECK: select
+  // CHECK: arith.maxf
   %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
   return
 }
@@ -803,9 +798,8 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
   // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CMIN]]{{.*}}outs(%[[INIT]]
   // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%[[FILL]] : tensor<?xf32>)
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
-  // CHECK:   %[[CMP:.+]] = arith.cmpf ogt, %arg1, %arg2 : f32
-  // CHECK:   %[[RES:.+]] = arith.select %[[CMP]], %arg1, %arg2 : f32
-  // CHECK:   linalg.yield %[[RES]] : f32
+  // CHECK:   %[[MAX:.+]] = arith.maxf %arg1, %arg2 : f32
+  // CHECK:   linalg.yield %[[MAX]] : f32
   // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
   %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?x1xf32>
   return


        


More information about the Mlir-commits mailing list