[Mlir-commits] [mlir] TosaToLinalg: Support unsigned tosa.clamp (PR #91749)
Matthias Gehre
llvmlistbot at llvm.org
Thu Jun 20 15:02:37 PDT 2024
https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/91749
From 4b801e0547eaf912f969e3320063e2d0f6641fba Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Thu, 4 Apr 2024 12:51:51 +0200
Subject: [PATCH] TosaToLinalg: Fix unsigned tosa.clamp
Plumb the TypeConverter into PointwiseConverter,
and emit unsigned comparisons when the input type is unsigned.
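
As an illustration of the intended lowering (taken from the new test cases in
tosa-to-linalg.mlir below; SSA names are illustrative), a clamp on an unsigned
tensor such as

  %u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>

now lowers to unsigned comparisons inside the linalg.generic body, roughly

  %lb  = arith.constant 4 : i32
  %ub  = arith.constant 32 : i32
  %tmp = arith.maxui %lb, %in : i32
  %out = arith.minui %ub, %tmp : i32

where the ui32 element type has been converted to signless i32 by the TOSA
type converter. Previously, signed arith.maxsi/arith.minsi were emitted
regardless of the element type's signedness.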
---
.../Conversion/TosaToLinalg/TosaToLinalg.h | 3 +-
.../mlir/Dialect/Tosa/Utils/ConversionUtils.h | 2 +-
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 101 +++++++++++-------
.../TosaToLinalg/TosaToLinalgNamed.cpp | 3 +-
.../TosaToLinalg/TosaToLinalgPass.cpp | 5 +-
.../Dialect/Tosa/Utils/ConversionUtils.cpp | 6 +-
.../TosaToLinalg/tosa-to-linalg.mlir | 30 +++++-
7 files changed, 105 insertions(+), 45 deletions(-)
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 67965e34d8a3d..c84e4f17c38d8 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -47,7 +47,8 @@ void addTosaToLinalgPasses(
void registerTosaToLinalgPipelines();
/// Populates conversion passes from TOSA dialect to Linalg dialect.
-void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToLinalgConversionPatterns(TypeConverter &converter,
+ RewritePatternSet *patterns);
/// Populates conversion passes from TOSA dialect to Linalg named operations.
void populateTosaToLinalgNamedConversionPatterns(
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index ca59b221d03eb..ceab7d9c628a5 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -37,7 +37,7 @@ Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
// Takes the parameters for a clamp and turns it into a series of ops for
// integer inputs.
Value clampIntHelper(Location loc, Value arg, Value min, Value max,
- OpBuilder &rewriter);
+ OpBuilder &rewriter, bool isUnsigned);
// Determines whether the integer value falls within the range of integer type.
bool validIntegerRange(IntegerType ty, int64_t value);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8ad8e41414656..1442f2ad72255 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -46,10 +46,9 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
-static Value
-createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
- ArrayRef<Type> resultTypes,
- PatternRewriter &rewriter) {
+static Value createLinalgBodyCalculationForElementwiseOp(
+ Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
+ ConversionPatternRewriter &rewriter) {
Location loc = op->getLoc();
auto elementTy =
cast<ShapedType>(op->getOperand(0).getType()).getElementType();
@@ -186,7 +185,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
Value max = rewriter.create<arith::ConstantIntOp>(
loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
intermediateType);
- auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
+ auto clamp =
+ clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
// Truncate to the final value.
return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
@@ -389,25 +389,33 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
int64_t max =
cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
+ int64_t minRepresentable = std::numeric_limits<int64_t>::min();
+ int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
if (intTy.isUnsignedInteger()) {
- min = std::max(min, (int64_t)0);
- max = std::min(
- max,
- APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
- } else {
- min =
- std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue());
- max =
- std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue());
+ minRepresentable = 0;
+ if (intTy.getIntOrFloatBitWidth() <= 63) {
+ maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
+ .getZExtValue();
+ }
+    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
+ // Ensure that min & max fit into signed n-bit constants.
+ minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue();
+ maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue();
}
+ // Ensure that the bounds are representable as n-bit signed/unsigned integers.
+ min = std::max(min, minRepresentable);
+ max = std::max(max, minRepresentable);
+ min = std::min(min, maxRepresentable);
+ max = std::min(max, maxRepresentable);
auto minVal = rewriter.create<arith::ConstantIntOp>(
loc, min, intTy.getIntOrFloatBitWidth());
auto maxVal = rewriter.create<arith::ConstantIntOp>(
loc, max, intTy.getIntOrFloatBitWidth());
- return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
+ return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
+ intTy.isUnsignedInteger());
}
// tosa::SigmoidOp
@@ -615,10 +623,9 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
}
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
- Location loc, Operation *operation) {
- auto rank =
- cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
- return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
+ Location loc, ValueRange operands,
+ int64_t rank) {
+ return llvm::map_to_vector(operands, [&](Value operand) {
return expandRank(rewriter, loc, operand, rank);
});
}
@@ -843,11 +850,16 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
}
static LogicalResult
-emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
+emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
Operation *operation, ValueRange operands,
- ArrayRef<OpFoldResult> targetShape) {
+ ArrayRef<OpFoldResult> targetShape,
+ const TypeConverter &converter) {
// Generate output tensor
- auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
+ auto resultType = cast_or_null<RankedTensorType>(
+ converter.convertType(operation->getResultTypes().front()));
+ if (!resultType) {
+ return rewriter.notifyMatchFailure(operation, "failed to convert type");
+ }
Value outputTensor = rewriter.create<tensor::EmptyOp>(
loc, targetShape, resultType.getElementType());
@@ -894,8 +906,9 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
}
static LogicalResult
-elementwiseMatchAndRewriteHelper(Operation *operation,
- PatternRewriter &rewriter) {
+elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
+ ConversionPatternRewriter &rewriter,
+ const TypeConverter &converter) {
// Collect op properties
assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
@@ -908,13 +921,15 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
// Lower operation
IndexPool indexPool;
auto loc = operation->getLoc();
- auto expandedOperands = expandInputRanks(rewriter, loc, operation);
+ auto rank =
+ cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
+ auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
auto [targetShape, masterOperands] =
computeTargetShape(rewriter, loc, indexPool, expandedOperands);
auto broadcastOperands = broadcastDynamicDimensions(
rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
- targetShape);
+ targetShape, converter);
}
// Returns the constant initial value for a given reduction operation. The
@@ -1100,13 +1115,16 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
namespace {
template <typename SrcOp>
-class PointwiseConverter : public OpRewritePattern<SrcOp> {
+class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
- using OpRewritePattern<SrcOp>::OpRewritePattern;
+ using OpConversionPattern<SrcOp>::OpConversionPattern;
+ using typename OpConversionPattern<SrcOp>::OpAdaptor;
- LogicalResult matchAndRewrite(SrcOp op,
- PatternRewriter &rewriter) const final {
- return elementwiseMatchAndRewriteHelper(op, rewriter);
+ LogicalResult
+ matchAndRewrite(SrcOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const final {
+ return elementwiseMatchAndRewriteHelper(
+ op, operands.getOperands(), rewriter, *this->getTypeConverter());
}
};
@@ -1279,7 +1297,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
loc, nestedBuilder.getI32IntegerAttr(intMax));
value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
- nestedBuilder);
+ nestedBuilder, /*isUnsigned=*/false);
if (outIntType.getWidth() < 32) {
value = nestedBuilder.create<arith::TruncIOp>(
@@ -1643,7 +1661,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
val = b.create<arith::AddIOp>(val, offset);
- val = clampIntHelper(loc, val, zeroI32, max, b);
+ val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
return b.create<arith::IndexCastOp>(b.getIndexType(), val);
};
@@ -1664,8 +1682,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Value max, ImplicitLocOpBuilder &b) {
val0 = in;
val1 = b.create<arith::AddIOp>(val0, oneVal);
- val0 = clampIntHelper(loc, val0, zeroI32, max, b);
- val1 = clampIntHelper(loc, val1, zeroI32, max, b);
+ val0 =
+ clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
+ val1 =
+ clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
};
@@ -2555,7 +2575,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
} // namespace
void mlir::tosa::populateTosaToLinalgConversionPatterns(
- RewritePatternSet *patterns) {
+ TypeConverter &converter, RewritePatternSet *patterns) {
// We have multiple resize converters to handle degenerate cases.
patterns->add<GenericResizeConverter>(patterns->getContext(),
@@ -2602,7 +2622,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
PointwiseConverter<tosa::CeilOp>,
PointwiseConverter<tosa::FloorOp>,
PointwiseConverter<tosa::ClampOp>,
- PointwiseConverter<tosa::SigmoidOp>,
+ PointwiseConverter<tosa::SigmoidOp>
+ >(converter, patterns->getContext());
+
+ patterns->add<
IdentityNConverter<tosa::IdentityOp>,
ReduceConverter<tosa::ReduceAllOp>,
ReduceConverter<tosa::ReduceAnyOp>,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index d8fb3abc0bef8..77c3d2e875791 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1015,7 +1015,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
auto max = rewriter.create<arith::ConstantIntOp>(
loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
accETy);
- auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);
+ auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
+ /*isUnsigned=*/false);
poolVal = clamp;
// Convert type.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 8904e3253922c..44036d7c31a91 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -63,8 +63,11 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+ TypeConverter converter;
+ tosa::populateTosaTypeConversion(converter);
+
FunctionOpInterface func = getOperation();
- mlir::tosa::populateTosaToLinalgConversionPatterns(&patterns);
+ mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 4fc97115064f3..f276924a8a9f6 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -38,7 +38,11 @@ Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
}
Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
- OpBuilder &rewriter) {
+ OpBuilder &rewriter, bool isUnsigned) {
+ if (isUnsigned) {
+ auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
+ return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
+ }
auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 45b39f79a2a72..5187d79fd4c0b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -606,7 +606,7 @@ func.func @test_simple_ui8(%arg0: tensor<1xui8>) -> () {
// -----
// CHECK-LABEL: @test_simple_i32
-func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
+func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %unsigned64: tensor<1xui64>) -> () {
// CHECK: linalg.generic
// CHECK: arith.addi
%0 = tosa.add %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
@@ -700,6 +700,34 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK-DAG: arith.minsi
%19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: linalg.generic
+ // CHECK-DAG: %[[LB:.*]] = arith.constant 4 : i32
+ // CHECK-DAG: %[[UB:.*]] = arith.constant 32 : i32
+ // CHECK-DAG: arith.maxui %[[LB]],
+ // CHECK-DAG: arith.minui %[[UB]],
+ %u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
+ // CHECK: linalg.generic
+ // CHECK-DAG: %[[LB:.*]] = arith.constant -1 : i32
+ // CHECK-DAG: %[[UB:.*]] = arith.constant -1 : i32
+ // CHECK-DAG: arith.maxui %[[LB]],
+ // CHECK-DAG: arith.minui %[[UB]],
+ %u1 = tosa.clamp %unsigned {min_int = 9223372036854775807 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
+ // CHECK: linalg.generic
+ // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
+ // CHECK-DAG: %[[UB:.*]] = arith.constant 0 : i32
+ // CHECK-DAG: arith.maxui %[[LB]],
+ // CHECK-DAG: arith.minui %[[UB]],
+ %u2 = tosa.clamp %unsigned {min_int = -3 : i64, max_int = -2 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
+ // CHECK: linalg.generic
+ // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i64
+ // CHECK-DAG: %[[UB:.*]] = arith.constant 9223372036854775807 : i64
+ // CHECK-DAG: arith.maxui %[[LB]],
+ // CHECK-DAG: arith.minui %[[UB]],
+ %u3 = tosa.clamp %unsigned64 {min_int = -3 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>
+
// CHECK: linalg.generic
// CHECK: arith.trunci
%20 = tosa.cast %0 : (tensor<1xi32>) -> tensor<1xi16>