[Mlir-commits] [mlir] 286a9d4 - [mlir][tosa] Add lowering for tosa.rescale to linalg.generic
Rob Suderman
llvmlistbot at llvm.org
Thu Mar 18 16:15:09 PDT 2021
Author: Rob Suderman
Date: 2021-03-18T16:14:05-07:00
New Revision: 286a9d467ea904490548a25e3c73ad0d50190b43
URL: https://github.com/llvm/llvm-project/commit/286a9d467ea904490548a25e3c73ad0d50190b43
DIFF: https://github.com/llvm/llvm-project/commit/286a9d467ea904490548a25e3c73ad0d50190b43.diff
LOG: [mlir][tosa] Add lowering for tosa.rescale to linalg.generic
This adds a tosa.apply_scale operation that handles the scaling operation
common to quantized operations. This scalar operation is lowered
in TosaToStandard.
We use a separate ApplyScale factorization as this is a replicable pattern
within TOSA. ApplyScale can be reused within pool/convolution/mul/matmul
for their quantized variants.
Tests are added to both tosa-to-standard and tosa-to-linalg-on-tensors
that verify each pass is correct.
Reviewed By: silvas
Differential Revision: https://reviews.llvm.org/D98753
Added:
Modified:
mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
index 82555003661e..5a63d787b38a 100644
--- a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
+++ b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
@@ -23,6 +23,9 @@ std::unique_ptr<Pass> createTosaToStandard();
void populateTosaToStandardConversionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns);
+void populateTosaRescaleToStandardConversionPatterns(
+ MLIRContext *context, OwningRewritePatternList *patterns);
+
/// Populates passes to convert from TOSA to Standard.
void addTosaToStandardPasses(OpPassManager &pm);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c9790596ed88..576471562bf3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1494,6 +1494,30 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect]> {
);
}
+def Tosa_ApplyScaleOp: Tosa_Op<"apply_scale", [NoSideEffect] # ElementwiseMappable.traits> {
+ let summary = "Rescale scalar operator for Tosa tensor operators";
+
+ let description = [{
+ Applies rescaling for fixed point values. This behavior is replicated in
+ multiple quantized operations (mul, convolution, rescale, matmul, pooling).
+
+ The commonplace implementation is to use i64 operations to avoid integer
+ overflow; target-specific implementations can use native operations to
+ avoid wider-than-necessary types.
+ }];
+
+ let arguments = (ins
+ Tosa_Int32Like:$value,
+ Tosa_Int32Like:$multiplier,
+ Tosa_Int8Like:$shift,
+ BoolAttr:$double_round
+ );
+
+ let results = (outs
+ Tosa_Int32:$output
+ );
+}
+
//===----------------------------------------------------------------------===//
// TOSA Spec Section 2.13
// Operator Class: Data Node Ops.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index fe3eee7168c6..64314f06aac2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -127,6 +127,21 @@ def Tosa_Tensor1Dto6D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>;
def Tosa_TensorUpto4D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>;
def Tosa_TensorUpto6D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4,5,6]>;
+//===----------------------------------------------------------------------===//
+// Generic scalar, vector, or tensor of a particular type.
+//===----------------------------------------------------------------------===//
+
+class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
+ AnyTypeOf<types>.predicate,
+ VectorOf<types>.predicate,
+ TensorOf<types>.predicate]>,
+ "signless-integer-32-like">;
+
+def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
+def Tosa_Int16Like : Tosa_TypeLike<[Tosa_Int16], "signless-integer-16-bit-like">;
+def Tosa_Int32Like : Tosa_TypeLike<[Tosa_Int32], "signless-integer-32-bit-like">;
+def Tosa_Int64Like : Tosa_TypeLike<[Tosa_Int64], "signless-integer-64-bit-like">;
+
//===----------------------------------------------------------------------===//
// Attribute predicates and classes.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index dd2725cbd0fa..5db47b423d89 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -32,22 +32,49 @@ static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
template <typename T>
static mlir::ConstantOp
createConstFromIntAttribute(Operation *op, std::string attrName,
- Type requiredAttrType, PatternRewriter &rewriter) {
+ Type requiredAttrType, OpBuilder &rewriter) {
auto castedN = static_cast<T>(
op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
return rewriter.create<mlir::ConstantOp>(
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
+template <typename T>
+static void getValuesFromIntArrayAttribute(ArrayAttr attr,
+ SmallVector<T> &arrayValues) {
+ for (Attribute val : attr.getValue()) {
+ arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
+ }
+}
+
+// Generates an affine map for parallel operations on a given type. This
+// performs implicit broadcasting across any dimension of size-1.
+static AffineMap createAffineMapForType(ShapedType type,
+ PatternRewriter &rewriter) {
+ unsigned rank = type.getRank();
+ auto shape = type.getShape();
+ SmallVector<AffineExpr, 4> dimExprs;
+ dimExprs.reserve(rank);
+ for (unsigned i = 0; i < rank; ++i) {
+ // If the dimension is one we can broadcast the input with a constant
+ // affine expression.
+ if (shape[i] == 1)
+ dimExprs.push_back(rewriter.getAffineConstantExpr(0));
+ else
+ dimExprs.push_back(rewriter.getAffineDimExpr(i));
+ }
+ return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
+ rewriter.getContext());
+}
+
template <typename T, typename P>
-static mlir::SelectOp clampHelper(Operation *op, ValueRange args,
- mlir::ConstantOp min, mlir::ConstantOp max,
- P pred, PatternRewriter &rewriter) {
- Location loc = op->getLoc();
- auto smallerThanMin = rewriter.create<T>(loc, pred, args[0], min);
+static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min,
+ mlir::ConstantOp max, P pred,
+ OpBuilder &rewriter) {
+ auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
auto minOrArg =
- rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, args[0]);
- auto largerThanMax = rewriter.create<T>(loc, pred, max, args[0]);
+ rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
+ auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
}
@@ -211,7 +238,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
op->getAttr("min_fp"));
auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
op->getAttr("max_fp"));
- return clampHelper<mlir::CmpFOp>(op, args, min, max, CmpFPredicate::OLT,
+ return clampHelper<mlir::CmpFOp>(loc, args[0], min, max, CmpFPredicate::OLT,
rewriter);
}
@@ -220,7 +247,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
rewriter);
auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
rewriter);
- return clampHelper<mlir::CmpIOp>(op, args, min, max, CmpIPredicate::slt,
+ return clampHelper<mlir::CmpIOp>(loc, args[0], min, max, CmpIPredicate::slt,
rewriter);
}
@@ -230,7 +257,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
op->getAttr("max_fp"));
- return clampHelper<mlir::CmpFOp>(op, args, zero, n, CmpFPredicate::OLT,
+ return clampHelper<mlir::CmpFOp>(loc, args[0], zero, n, CmpFPredicate::OLT,
rewriter);
}
@@ -239,7 +266,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
rewriter);
- return clampHelper<mlir::CmpIOp>(op, args, zero, n, CmpIPredicate::slt,
+ return clampHelper<mlir::CmpIOp>(loc, args[0], zero, n, CmpIPredicate::slt,
rewriter);
}
@@ -290,21 +317,9 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
// Input indexing maps may be broadcasted.
- for (Type types : operation->getOperandTypes()) {
- auto shape = types.cast<ShapedType>().getShape();
- SmallVector<AffineExpr, 4> dimExprs;
- dimExprs.reserve(nloops);
- for (unsigned i = 0; i < nloops; ++i) {
- // If the dimension is one we can broadcast the input with a constant
- // affine expression.
- if (shape[i] == 1)
- dimExprs.push_back(rewriter.getAffineConstantExpr(0));
- else
- dimExprs.push_back(rewriter.getAffineDimExpr(i));
- }
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/nloops,
- /*symbolCount=*/0, dimExprs,
- rewriter.getContext()));
+ for (Type type : operation->getOperandTypes()) {
+ indexingMaps.push_back(
+ createAffineMapForType(type.cast<ShapedType>(), rewriter));
}
indexingMaps.append(operation->getNumResults(),
@@ -632,6 +647,142 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
}
};
+class RescaleOpConverter : public OpRewritePattern<tosa::RescaleOp> {
+public:
+ using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::RescaleOp op,
+ PatternRewriter &rewriter) const final {
+ auto loc = op.getLoc();
+ auto input = op.input();
+ auto inputTy = op.input().getType().cast<ShapedType>();
+ auto outputTy = op.output().getType().cast<ShapedType>();
+ unsigned rank = inputTy.getRank();
+
+ if (!outputTy.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ op, "tosa to linalg conversion expects statically shaped tensors");
+
+ // The shift and multiplier values.
+ SmallVector<int32_t> multiplierValues;
+ getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues);
+
+ SmallVector<int8_t> shiftValues;
+ getValuesFromIntArrayAttribute(op.shift(), shiftValues);
+
+ // Double rounding only occurs if the shift is greater than 31; check
+ // whether this is ever true.
+ bool doubleRound =
+ op.double_round() &&
+ llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+
+ // We need to broadcast along the last dimension, so make all dims 1.
+ SmallVector<int64_t> multiplierShape;
+ multiplierShape.resize(rank, 1);
+
+ SmallVector<int64_t> shiftShape;
+ shiftShape.resize(rank, 1);
+
+ // Set the channel dimension to match the number of shift/broadcast
+ // channels.
+ if (!multiplierShape.empty())
+ multiplierShape.back() = multiplierValues.size();
+ if (!shiftShape.empty())
+ shiftShape.back() = shiftValues.size();
+
+ // Create the tensor types.
+ auto multiplierType =
+ RankedTensorType::get(multiplierShape, rewriter.getI32Type());
+ auto shiftType =
+ RankedTensorType::get(shiftShape, rewriter.getIntegerType(8));
+
+ auto multiplierConst = rewriter.create<ConstantOp>(
+ loc, DenseIntElementsAttr::get(multiplierType, multiplierValues));
+
+ auto shiftConst = rewriter.create<ConstantOp>(
+ loc, DenseIntElementsAttr::get(shiftType, shiftValues));
+
+ // Construct the indexing maps needed for linalg.generic ops.
+ SmallVector<Type> bodyArgTypes = {getElementTypeOrSelf(inputTy),
+ rewriter.getI32Type(),
+ rewriter.getI32Type()};
+ Value initTensor = rewriter.create<linalg::InitTensorOp>(
+ loc, ArrayRef<Value>({}), outputTy.getShape(),
+ outputTy.getElementType());
+
+ SmallVector<AffineMap, 4> indexingMaps;
+
+ // Indexing map for input values.
+ indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
+
+ // Shift and multiplier will need to broadcast across their non channel
+ // values.
+ indexingMaps.push_back(createAffineMapForType(multiplierType, rewriter));
+ indexingMaps.push_back(createAffineMapForType(shiftType, rewriter));
+
+ // Indexing maps for output values.
+ indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
+
+ auto linalgOp = rewriter.create<linalg::GenericOp>(
+ loc, outputTy, ValueRange{input, multiplierConst, shiftConst},
+ ValueRange{initTensor}, indexingMaps, getNParallelLoopsAttrs(rank),
+ [&](OpBuilder &nestedBuilder, Location nestedLoc,
+ ValueRange blockArgs) {
+ // For now we do all of our math in 64-bit. This is not optimal but
+ // should be correct for now, consider computing correct bit depth
+ // later.
+ auto inputZp = createConstFromIntAttribute<int32_t>(
+ op, "input_zp", nestedBuilder.getI32Type(), nestedBuilder);
+ auto outputZp = createConstFromIntAttribute<int32_t>(
+ op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
+
+ Value value = blockArgs[0];
+ Value multiplier = blockArgs[1];
+ Value shift = blockArgs[2];
+
+ if (value.getType().getIntOrFloatBitWidth() < 32) {
+ value = nestedBuilder.create<SignExtendIOp>(
+ nestedLoc, nestedBuilder.getI32Type(), value);
+ }
+
+ value = nestedBuilder.create<SubIOp>(nestedLoc, value, inputZp);
+
+ value = nestedBuilder.create<tosa::ApplyScaleOp>(
+ loc, nestedBuilder.getI32Type(), value, multiplier, shift,
+ nestedBuilder.getBoolAttr(doubleRound));
+
+ // Move to the new zero-point.
+ value = nestedBuilder.create<AddIOp>(nestedLoc, value, outputZp);
+
+ // Saturate to the output size.
+ IntegerType outIntType =
+ blockArgs.back().getType().cast<IntegerType>();
+ unsigned outBitWidth = outIntType.getWidth();
+ auto intMin = nestedBuilder.create<ConstantOp>(
+ loc, nestedBuilder.getIntegerAttr(
+ nestedBuilder.getI32Type(),
+ APInt::getSignedMinValue(outBitWidth).getSExtValue()));
+ auto intMax = nestedBuilder.create<ConstantOp>(
+ loc, nestedBuilder.getIntegerAttr(
+ nestedBuilder.getI32Type(),
+ APInt::getSignedMaxValue(outBitWidth).getSExtValue()));
+
+ value = clampHelper<mlir::CmpIOp>(nestedLoc, value, intMin, intMax,
+ CmpIPredicate::slt, nestedBuilder);
+
+ if (outIntType.getWidth() < 32) {
+ value =
+ nestedBuilder.create<TruncateIOp>(nestedLoc, outIntType, value);
+ }
+
+ nestedBuilder.create<linalg::YieldOp>(loc, value);
+ });
+
+ rewriter.replaceOp(op, linalgOp->getResults());
+ return success();
+ }
+};
+
// At the codegen level any identity operations should be removed. Any cases
// where identity is load-bearing (e.g. cross device computation) should be
// handled before lowering to codegen.
@@ -729,5 +880,5 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
ReduceConverter<tosa::ReduceProdOp>, ConcatOpConversion,
- ReshapeOpConverter, TransposeConverter>(context);
+ ReshapeOpConverter, TransposeConverter, RescaleOpConverter>(context);
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index a1bd694f67af..e0f1369b43a5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -42,6 +42,13 @@ struct TosaToLinalgOnTensors
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
StandardOpsDialect>();
target.addIllegalDialect<tosa::TosaDialect>();
+
+ // Not every TOSA op can be legalized to linalg.
+ target.addLegalOp<tosa::ApplyScaleOp>();
+ target.addLegalOp<tosa::IfOp>();
+ target.addLegalOp<tosa::ConstOp>();
+ target.addLegalOp<tosa::WhileOp>();
+
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
FuncOp func = getFunction();
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
index 6e5411dd5ecb..95f5c51ff1f0 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -46,7 +46,107 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
ValueRange({}), sliceOp.start(), sliceOp.size(),
rewriter.getI64ArrayAttr(strides));
+ return success();
+ }
+};
+
+// This converts the TOSA ApplyScale operator to a set of StandardOps ops,
+// using 64-bit operations to perform the necessary multiply, bias, and shift.
+// Multiple types are used to use minimal bit width operations.
+class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
+public:
+ using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
+ PatternRewriter &rewriter) const final {
+ Location loc = op.getLoc();
+ Value value32 = op.value();
+ Value multiplier32 = op.multiplier();
+ Value shift8 = op.shift();
+ bool doubleRound = op.double_round();
+
+ Value one8 = rewriter.create<ConstantOp>(
+ loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1));
+ Value one32 = rewriter.create<ConstantOp>(
+ loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1));
+ Value one64 = rewriter.create<ConstantOp>(
+ loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
+
+ Value shiftSubOne8 = rewriter.create<SubIOp>(loc, shift8, one8);
+
+ // The rounding value semantics below equate to the following code:
+ // int64_t round = 1 << (shift - 1);
+ // if (double_round) {
+ // if (shift > 31 && value >= 0) round += 1<<30;
+ // if (shift > 31 && value < 0) round -= 1<<30;
+ // }
+ //
+ // Note that minimal bitwidth operators are used throughout the block.
+
+ Value shift32 = rewriter.create<mlir::SignExtendIOp>(
+ loc, rewriter.getI32Type(), shift8);
+
+ Value round64 = rewriter.create<mlir::ShiftLeftOp>(
+ loc, one64,
+ rewriter.create<SignExtendIOp>(loc, rewriter.getI64Type(),
+ shiftSubOne8));
+
+ // Double rounding is performing a round operation before the shift
+ if (doubleRound) {
+ Value zero32 = rewriter.create<ConstantOp>(
+ loc, rewriter.getZeroAttr(rewriter.getI32Type()));
+ Value thirty32 = rewriter.create<ConstantOp>(
+ loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30));
+ Value shiftThirty32 =
+ rewriter.create<mlir::ShiftLeftOp>(loc, one32, thirty32);
+ Value shiftThirty64 = rewriter.create<mlir::SignExtendIOp>(
+ loc, rewriter.getI64Type(), shiftThirty32);
+
+ // The round value needs to be added or subtracted depending on the sign
+ // of the input value.
+ Value roundAdd64 =
+ rewriter.create<mlir::AddIOp>(loc, round64, shiftThirty64);
+ Value roundSub64 =
+ rewriter.create<mlir::SubIOp>(loc, round64, shiftThirty64);
+
+ Value valueGreaterThanZero = rewriter.create<mlir::CmpIOp>(
+ loc, CmpIPredicate::sge, value32, zero32);
+
+ Value doubleRound64 = rewriter.create<mlir::SelectOp>(
+ loc, valueGreaterThanZero, roundAdd64, roundSub64);
+
+ // We only perform double rounding if the shift value is 32 or greater.
+ Value thirtyTwo32 = rewriter.create<ConstantOp>(
+ loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 32));
+ Value shiftGreaterThanThirtyTwo = rewriter.create<mlir::CmpIOp>(
+ loc, CmpIPredicate::sge, shift32, thirtyTwo32);
+ round64 = rewriter.create<mlir::SelectOp>(loc, shiftGreaterThanThirtyTwo,
+ doubleRound64, round64);
+ }
+
+ // The computation below equates to the following pseudocode:
+ // int64_t result = (int64_t)value * multiplier + round;
+ // result = result >> shift;
+ //
+ // Note that multiply and shift need to be performed in i64 to preserve bits.
+
+ Value value64 =
+ rewriter.create<SignExtendIOp>(loc, rewriter.getI64Type(), value32);
+ Value multiplier64 = rewriter.create<SignExtendIOp>(
+ loc, rewriter.getI64Type(), multiplier32);
+ Value shift64 =
+ rewriter.create<SignExtendIOp>(loc, rewriter.getI64Type(), shift8);
+
+ // Multiply as a pair of i64 values to guarantee the end value fits.
+ Value result64 = rewriter.create<MulIOp>(loc, value64, multiplier64);
+ result64 = rewriter.create<AddIOp>(loc, result64, round64);
+ result64 =
+ rewriter.create<mlir::SignedShiftRightOp>(loc, result64, shift64);
+
+ Value result32 = rewriter.create<mlir::TruncateIOp>(
+ loc, rewriter.getI32Type(), result64);
+
+ rewriter.replaceOp(op, result32);
return success();
}
};
@@ -55,5 +155,11 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
void mlir::tosa::populateTosaToStandardConversionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) {
- patterns->insert<ConstOpConverter, SliceOpConverter>(context);
+ patterns->insert<ApplyScaleOpConverter, ConstOpConverter, SliceOpConverter>(
+ context);
+}
+
+void mlir::tosa::populateTosaRescaleToStandardConversionPatterns(
+ MLIRContext *context, OwningRewritePatternList *patterns) {
+ patterns->insert<ApplyScaleOpConverter>(context);
}
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
index 78a0e65da81b..14c800e2f70d 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
@@ -33,6 +33,7 @@ struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
ConversionTarget target(getContext());
target.addIllegalOp<tosa::ConstOp>();
target.addIllegalOp<tosa::SliceOp>();
+ target.addIllegalOp<tosa::ApplyScaleOp>();
target.addLegalDialect<StandardOpsDialect>();
auto *op = getOperation();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 9b1f6054ee06..1714f140dbfc 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -473,3 +473,54 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
%1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> (tensor<5x2xf32>)
return
}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)>
+
+// CHECK-LABEL: @rescale
+func @rescale(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) {
+ // CHECK: [[C0:%.+]] = constant dense<19689>
+ // CHECK: [[C1:%.+]] = constant dense<15>
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [1]
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[C0]], [[C1]] : tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) outs([[INIT]] : tensor<1xi8>)
+ // CHECK: ^bb0([[IN:%.+]]: i8, [[MULTIPLIER:%.+]]: i32, [[SHIFT:%.+]]: i8, [[UNUSED:%.+]]: i8):
+ // CHECK: [[C243:%.+]] = constant 243
+ // CHECK: [[C252:%.+]] = constant 252
+
+ // CHECK-DAG: [[IN32:%.+]] = sexti [[IN]]
+ // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C243]]
+ // CHECK-DAG: [[SCALED:%.+]] = "tosa.apply_scale"([[IN_ZEROED]], [[MULTIPLIER]], [[SHIFT]]) {double_round = false}
+ // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C252]]
+ // CHECK-DAG: [[CMIN:%.+]] = constant -128
+ // CHECK-DAG: [[CMAX:%.+]] = constant 127
+ // CHECK-DAG: [[MINLT:%.+]] = cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
+ // CHECK-DAG: [[MAXLT:%.+]] = cmpi slt, [[CMAX]], [[SCALED_ZEROED]]
+ // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]]
+ // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]]
+ // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]]
+ // CHECK-DAG: linalg.yield [[TRUNC]]
+ %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>)
+
+ // CHECK: return [[GENERIC]]
+ return %0 : tensor<1xi8>
+}
+
+// CHECK-LABEL: @rescaleDoubleRound
+func @rescaleDoubleRound(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) {
+ // CHECK: linalg.generic
+ // CHECK: "tosa.apply_scale"
+ // CHECK-SAME: {double_round = true}
+ %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [33 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>)
+ return %0 : tensor<1xi8>
+}
+
+// CHECK-LABEL: @rescaleUnnecessaryDoubleRound
+func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) {
+ // CHECK: linalg.generic
+ // CHECK: "tosa.apply_scale"
+ // CHECK-SAME: {double_round = false}
+ %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>)
+ return %0 : tensor<1xi8>
+}
diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
index 94925aec15c7..2c80c31cf297 100644
--- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -9,10 +9,46 @@ func @const_test() -> (tensor<i32>) {
return %0 : tensor<i32>
}
-// ----
+// -----
func @slice(%arg0: tensor<6xf32>) ->() {
// CHECK: [[SLICE:%.+]] = subtensor %arg0[2] [1] [1]
%0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>)
return
}
+
+// -----
+
+func @apply_scale_test(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
+ // CHECK: [[C1_8:%.+]] = constant 1 : i8
+ // CHECK: [[C1_32:%.+]] = constant 1 : i32
+ // CHECK: [[C1_64:%.+]] = constant 1 : i64
+ // CHECK: [[SHIFT_MINUS_ONE_8:%.+]] = subi %arg2, [[C1_8]]
+
+ // CHECK: [[SHIFT_32:%.+]] = sexti %arg2 : i8 to i32
+ // CHECK: [[SHIFT_MINUS_ONE_64:%.+]] = sexti [[SHIFT_MINUS_ONE_8]] : i8 to i64
+ // CHECK: [[SHIFTED_64:%.+]] = shift_left [[C1_64]], [[SHIFT_MINUS_ONE_64]]
+
+ // CHECK: [[C0_32:%.+]] = constant 0 : i32
+ // CHECK: [[C30_32:%.+]] = constant 30 : i32
+ // CHECK: [[SECOND_BIAS:%.+]] = shift_left [[C1_32]], [[C30_32]]
+ // CHECK: [[SECOND_BIAS_64:%.+]] = sexti [[SECOND_BIAS]] : i32 to i64
+ // CHECK: [[POSITIVE_ROUND:%.+]] = addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+ // CHECK: [[NEGATIVE_ROUND:%.+]] = subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+ // CHECK: [[VALUE_NEGATIVE:%.+]] = cmpi sge, %arg0, [[C0_32]] : i32
+ // CHECK: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64
+ // CHECK: [[C32_32:%.+]] = constant 32 : i32
+ // CHECK: [[IS_32BIT_SHIFT:%.+]] = cmpi sge, [[SHIFT_32]], [[C32_32]]
+ // CHECK: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
+
+ // CHECK: [[VAL_64:%.+]] = sexti %arg0 : i32 to i64
+ // CHECK: [[MULTIPLY_64:%.+]] = sexti %arg1 : i32 to i64
+ // CHECK: [[SHIFT_64:%.+]] = sexti %arg2 : i8 to i64
+ // CHECK: [[SCALED:%.+]] = muli [[VAL_64]], [[MULTIPLY_64]]
+ // CHECK: [[BIASED:%.+]] = addi [[SCALED]], [[ROUND]]
+ // CHECK: [[DOWNSHIFTED:%.+]] = shift_right_signed [[BIASED]], [[SHIFT_64]]
+ // CHECK: [[TRUNCATED:%.+]] = trunci [[DOWNSHIFTED]]
+
+ %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32
+ return %0 : i32
+}
More information about the Mlir-commits
mailing list