[Mlir-commits] [mlir] 2eb50ce - [mlir][tosa] Use arith::maxf/arith::minf in lowering from tosa
Thomas Raoux
llvmlistbot at llvm.org
Mon Aug 8 18:13:35 PDT 2022
Author: Thomas Raoux
Date: 2022-08-09T01:10:32Z
New Revision: 2eb50cee11ccbfac71eeb7687b9f136d95fc7f52
URL: https://github.com/llvm/llvm-project/commit/2eb50cee11ccbfac71eeb7687b9f136d95fc7f52
DIFF: https://github.com/llvm/llvm-project/commit/2eb50cee11ccbfac71eeb7687b9f136d95fc7f52.diff
LOG: [mlir][tosa] Use arith::maxf/arith::minf in lowering from tosa
Now that the `arith` dialect has maxf/minf, use them instead of cmp/select.
Also refactor the clamp helpers to make them simpler.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D131426
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
index 3e50358835ab9..11509c8a4f6bf 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
@@ -27,17 +27,15 @@ SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops);
// Takes a vector of values and condenses them to a vector with no gaps.
SmallVector<Value> condenseValues(const SmallVector<Value> &values);
-// Takes the parameters for a clamp and turns it into a series of ops.
-template <typename T, typename P>
-arith::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min,
- arith::ConstantOp max, P pred,
- OpBuilder &rewriter) {
- auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
- auto minOrArg =
- rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
- auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
- return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
-}
+// Emits ops that clamp the floating-point value `arg` to the range
+// [`min`, `max`] using arith.minf/arith.maxf (assumes `min` <= `max`).
+Value clampFloatHelper(Location loc, Value arg, arith::ConstantOp min,
+ arith::ConstantOp max, OpBuilder &rewriter);
+
+// Emits ops that clamp the integer value `arg` to the range [`min`, `max`]
+// using signed (slt) compares and selects.
+Value clampIntHelper(Location loc, Value arg, arith::ConstantOp min,
+ arith::ConstantOp max, OpBuilder &rewriter);
// Returns the values in an attribute as an array of values.
template <typename T>
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index efaf612360852..374c663511599 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -182,8 +182,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
auto max = rewriter.create<arith::ConstantIntOp>(
loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
intermediateType);
- auto clamp = clampHelper<arith::CmpIOp>(
- loc, sub, min, max, arith::CmpIPredicate::slt, rewriter);
+ auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
// Truncate to the final value.
return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
@@ -335,9 +334,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
// tosa::MaximumOp
if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
- auto predicate = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OGT, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
}
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -348,9 +345,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
// tosa::MinimumOp
if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
- auto predicate = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OLT, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
}
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
@@ -380,8 +375,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
loc, elementTy, rewriter.getFloatAttr(elementTy, min_apf));
auto max = rewriter.create<arith::ConstantOp>(
loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
- return clampHelper<arith::CmpFOp>(loc, args[0], min, max,
- arith::CmpFPredicate::OLT, rewriter);
+ return clampFloatHelper(loc, args[0], min, max, rewriter);
}
if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
@@ -409,8 +403,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
loc, min, intTy.getIntOrFloatBitWidth());
auto maxVal = rewriter.create<arith::ConstantIntOp>(
loc, max, intTy.getIntOrFloatBitWidth());
- return clampHelper<arith::CmpIOp>(loc, args[0], minVal, maxVal,
- arith::CmpIPredicate::slt, rewriter);
+ return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
}
// tosa::ReluNOp
@@ -423,8 +416,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
APFloat::rmNearestTiesToEven, &losesInfo);
auto n = rewriter.create<arith::ConstantOp>(
loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
- return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
- arith::CmpFPredicate::OLT, rewriter);
+ return clampFloatHelper(loc, args[0], zero, n, rewriter);
}
if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
@@ -432,8 +424,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
rewriter);
- return clampHelper<arith::CmpIOp>(loc, args[0], zero, n,
- arith::CmpIPredicate::slt, rewriter);
+ return clampIntHelper(loc, args[0], zero, n, rewriter);
}
// tosa::SigmoidOp
@@ -521,8 +512,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
auto rounded =
rewriter.create<arith::SelectOp>(loc, negative, subbed, added);
- auto clamped = clampHelper<arith::CmpFOp>(
- loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter);
+ auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
}
@@ -553,8 +543,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
.getSExtValue(),
srcTy.getIntOrFloatBitWidth());
- auto clamped = clampHelper<arith::CmpIOp>(
- loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter);
+ auto clamped = clampIntHelper(loc, args[0], intMin, intMax, rewriter);
return rewriter.create<arith::TruncIOp>(loc, dstTy, clamped);
}
}
@@ -751,9 +740,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
}
if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
- auto predicate = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OLT, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
}
if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
@@ -763,9 +750,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
}
if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
- auto predicate = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OGT, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
}
if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
@@ -1314,9 +1299,8 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
loc, nestedBuilder.getI32IntegerAttr(intMax));
- value = clampHelper<arith::CmpIOp>(
- nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt,
- nestedBuilder);
+ value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
+ nestedBuilder);
if (outIntType.getWidth() < 32) {
value = nestedBuilder.create<arith::TruncIOp>(
@@ -1497,10 +1481,8 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
// Clamp the to be within the bounds of the input image.
- iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax,
- arith::CmpIPredicate::slt, rewriter);
- ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax,
- arith::CmpIPredicate::slt, rewriter);
+ iy = clampIntHelper(loc, iy, hwMin, hMax, rewriter);
+ ix = clampIntHelper(loc, ix, hwMin, wMax, rewriter);
// Read the value from the input array.
iy =
@@ -1525,15 +1507,11 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);
- y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax,
- arith::CmpIPredicate::slt, rewriter);
- y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax,
- arith::CmpIPredicate::slt, rewriter);
+ y0 = clampIntHelper(loc, y0, hwMin, hMax, rewriter);
+ y1 = clampIntHelper(loc, y1, hwMin, hMax, rewriter);
- x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax,
- arith::CmpIPredicate::slt, rewriter);
- x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax,
- arith::CmpIPredicate::slt, rewriter);
+ x0 = clampIntHelper(loc, x0, hwMin, wMax, rewriter);
+ x1 = clampIntHelper(loc, x1, hwMin, wMax, rewriter);
y0 =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y0);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 5e491f2ef437c..42bca1ef8ff24 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -943,8 +943,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
auto max = rewriter.create<arith::ConstantIntOp>(
loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
accETy);
- auto clamp = clampHelper<arith::CmpIOp>(
- loc, scaled, min, max, arith::CmpIPredicate::slt, rewriter);
+ auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);
poolVal = clamp;
// Convert type.
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index e994adb29bf5c..33999f3ad36ce 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -28,3 +28,28 @@ mlir::tosa::condenseValues(const SmallVector<Value> &values) {
condensedValues.push_back(value);
return condensedValues;
}
+
+// Clamps the floating-point value `arg` to the closed range [`min`, `max`]
+// using arith.minf/arith.maxf. Assumes `min` <= `max`.
+Value mlir::tosa::clampFloatHelper(Location loc, Value arg,
+ arith::ConstantOp min, arith::ConstantOp max,
+ OpBuilder &rewriter) {
+ // Bound from above first, then from below: maxf(minf(arg, max), min).
+ // Pairing minf with `min` and maxf with `max` is wrong: minf(arg, min) never
+ // exceeds `max`, so the following maxf would return `max` for every input.
+ Value upperBounded = rewriter.create<arith::MinFOp>(loc, arg, max);
+ return rewriter.create<arith::MaxFOp>(loc, upperBounded, min);
+}
+
+// Clamps the signed integer value `arg` to the closed range [`min`, `max`]
+// using two signed (slt) arith.cmpi compares and arith.select ops.
+Value mlir::tosa::clampIntHelper(Location loc, Value arg, arith::ConstantOp min,
+ arith::ConstantOp max, OpBuilder &rewriter) {
+ auto smallerThanMin =
+ rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min);
+ auto minOrArg =
+ rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
+ auto largerThanMax =
+ rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
+ return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index cd405cdd03b04..47efb8e72cb1c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -198,13 +198,11 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
%13 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.cmpf
- // CHECK: select
+ // CHECK: arith.maxf
%14 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.cmpf
- // CHECK: select
+ // CHECK: arith.minf
%15 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
@@ -216,13 +214,13 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
%17 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.cmpf
- // CHECK: select
+ // CHECK: arith.minf
+ // CHECK: arith.maxf
%18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.cmpf
- // CHECK: select
+ // CHECK: arith.minf
+ // CHECK: arith.maxf
%19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
@@ -241,10 +239,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: arith.subf
// CHECK: arith.cmpf olt
// CHECK: select
- // CHECK: arith.cmpf olt
- // CHECK: select
- // CHECK: arith.cmpf olt
- // CHECK: select
+ // CHECK: arith.minf
+ // CHECK: arith.maxf
// CHECK: arith.fptosi
%21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
@@ -451,20 +447,22 @@ func.func @test_simple_ui8(%arg0: tensor<1xi8>) -> () {
// CHECK-LABEL: @test_i8
func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK-DAG: %[[C127:.+]] = arith.constant -127
// CHECK-DAG: %[[C126:.+]] = arith.constant 126
- // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %arg1, %[[C127]]
+ // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %[[ARG1]], %[[C127]]
// CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C127]]
- // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C126]], %arg1
+ // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C126]], %[[ARG1]]
// CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C126]], %[[SEL1]]
%0 = "tosa.clamp"(%arg0) {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK-DAG: %[[C128:.+]] = arith.constant -128
// CHECK-DAG: %[[C127:.+]] = arith.constant 127
- // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %arg1, %[[C128]]
+ // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %[[ARG1]], %[[C128]]
// CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C128]]
- // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C127]], %arg1
+ // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C127]], %[[ARG1]]
// CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C127]], %[[SEL1]]
%1 = "tosa.clamp"(%arg0) {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
@@ -476,12 +474,11 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// CHECK-LABEL: @test_clamp_f16
func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
// CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG1:.+]]: f16,
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
// CHECK-DAG: %[[C6:.+]] = arith.constant 6.0
- // CHECK-DAG: %[[CMP1:.+]] = arith.cmpf olt, %arg1, %[[C0]]
- // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C0]]
- // CHECK-DAG: %[[CMP2:.+]] = arith.cmpf olt, %[[C6]], %arg1
- // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C6]], %[[SEL1]]
+ // CHECK-DAG: %[[MIN:.+]] = arith.minf %[[ARG1]], %[[C6]]
+ // CHECK-DAG: %[[MAX:.+]] = arith.maxf %[[MIN]], %[[C0]]
%0 = "tosa.clamp"(%arg0) {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16>
return
@@ -732,15 +729,13 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
// CHECK: arith.constant 3.40282347E+38 : f32
// CHECK: linalg.fill
// CHECK: linalg.generic
- // CHECK: arith.cmpf olt
- // CHECK: select
+ // CHECK: arith.minf
%3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
// CHECK: arith.constant -3.40282347E+38 : f32
// CHECK: linalg.fill
// CHECK: linalg.generic
- // CHECK: arith.cmpf ogt
- // CHECK: select
+ // CHECK: arith.maxf
%4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
return
}
@@ -803,9 +798,8 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CMIN]]{{.*}}outs(%[[INIT]]
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%[[FILL]] : tensor<?xf32>)
// CHECK: ^bb0(%arg1: f32, %arg2: f32)
- // CHECK: %[[CMP:.+]] = arith.cmpf ogt, %arg1, %arg2 : f32
- // CHECK: %[[RES:.+]] = arith.select %[[CMP]], %arg1, %arg2 : f32
- // CHECK: linalg.yield %[[RES]] : f32
+ // CHECK: %[[MAX:.+]] = arith.maxf %arg1, %arg2 : f32
+ // CHECK: linalg.yield %[[MAX]] : f32
// CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
%0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?x1xf32>
return
More information about the Mlir-commits
mailing list