[Mlir-commits] [mlir] 286a9d4 - [mlir][tosa] Add lowering for tosa.rescale to linalg.generic

Rob Suderman llvmlistbot at llvm.org
Thu Mar 18 16:15:09 PDT 2021


Author: Rob Suderman
Date: 2021-03-18T16:14:05-07:00
New Revision: 286a9d467ea904490548a25e3c73ad0d50190b43

URL: https://github.com/llvm/llvm-project/commit/286a9d467ea904490548a25e3c73ad0d50190b43
DIFF: https://github.com/llvm/llvm-project/commit/286a9d467ea904490548a25e3c73ad0d50190b43.diff

LOG: [mlir][tosa] Add lowering for tosa.rescale to linalg.generic

This adds a tosa.apply_scale operation that handles the scaling operation
common to quantized operations. This scalar operation is lowered
in TosaToStandard.

We use a separate ApplyScale factorization as this is a replicable pattern
within TOSA. ApplyScale can be reused within pool/convolution/mul/matmul
for their quantized variants.

Tests are added to both tosa-to-standard and tosa-to-linalg-on-tensors
that verify each pass is correct.

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D98753

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
    mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
    mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
index 82555003661e..5a63d787b38a 100644
--- a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
+++ b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
@@ -23,6 +23,9 @@ std::unique_ptr<Pass> createTosaToStandard();
 void populateTosaToStandardConversionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns);
 
+void populateTosaRescaleToStandardConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
 /// Populates passes to convert from TOSA to Standard.
 void addTosaToStandardPasses(OpPassManager &pm);
 

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c9790596ed88..576471562bf3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1494,6 +1494,30 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect]> {
   );
 }
 
+def Tosa_ApplyScaleOp: Tosa_Op<"apply_scale", [NoSideEffect] # ElementwiseMappable.traits> {
+  let summary = "Rescale scalar operator for Tosa tensor operators";
+
+  let description = [{
+    Applies rescaling for fixed point values. This behavior is replicated in
+    multiple quantized operations (mul, convolution, rescale, matmul, pooling).
+
+    The commonplace implementation is to use i64 operations to avoid integer
+    overflow; target specific implementations can use native operations to
+    avoid wider than necessary types.
+  }];
+
+  let arguments = (ins
+    Tosa_Int32Like:$value,
+    Tosa_Int32Like:$multiplier,
+    Tosa_Int8Like:$shift,
+    BoolAttr:$double_round
+  );
+
+  let results = (outs
+    Tosa_Int32:$output
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Spec Section 2.13
 // Operator Class: Data Node Ops.

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index fe3eee7168c6..64314f06aac2 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -127,6 +127,21 @@ def Tosa_Tensor1Dto6D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>;
 def Tosa_TensorUpto4D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>;
 def Tosa_TensorUpto6D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4,5,6]>;
 
+//===----------------------------------------------------------------------===//
+// Generic scalar, vector, or tensor of a particular type.
+//===----------------------------------------------------------------------===//
+
+class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
+     AnyTypeOf<types>.predicate,
+     VectorOf<types>.predicate,
+     TensorOf<types>.predicate]>,
+     "signless-integer-32-like">;
+
+def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
+def Tosa_Int16Like : Tosa_TypeLike<[Tosa_Int16], "signless-integer-16-bit-like">;
+def Tosa_Int32Like : Tosa_TypeLike<[Tosa_Int32], "signless-integer-32-bit-like">;
+def Tosa_Int64Like : Tosa_TypeLike<[Tosa_Int64], "signless-integer-64-bit-like">;
+
 //===----------------------------------------------------------------------===//
 // Attribute predicates and classes.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index dd2725cbd0fa..5db47b423d89 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -32,22 +32,49 @@ static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
 template <typename T>
 static mlir::ConstantOp
 createConstFromIntAttribute(Operation *op, std::string attrName,
-                            Type requiredAttrType, PatternRewriter &rewriter) {
+                            Type requiredAttrType, OpBuilder &rewriter) {
   auto castedN = static_cast<T>(
       op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
   return rewriter.create<mlir::ConstantOp>(
       op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
 }
 
+template <typename T>
+static void getValuesFromIntArrayAttribute(ArrayAttr attr,
+                                           SmallVector<T> &arrayValues) {
+  for (Attribute val : attr.getValue()) {
+    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
+  }
+}
+
+// Generates an affine map for parallel operations on a given type. This
+// performs implicit broadcasting across any dimension of size-1.
+static AffineMap createAffineMapForType(ShapedType type,
+                                        PatternRewriter &rewriter) {
+  unsigned rank = type.getRank();
+  auto shape = type.getShape();
+  SmallVector<AffineExpr, 4> dimExprs;
+  dimExprs.reserve(rank);
+  for (unsigned i = 0; i < rank; ++i) {
+    // If the dimension is one we can broadcast the input with a constant
+    // affine expression.
+    if (shape[i] == 1)
+      dimExprs.push_back(rewriter.getAffineConstantExpr(0));
+    else
+      dimExprs.push_back(rewriter.getAffineDimExpr(i));
+  }
+  return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
+                        rewriter.getContext());
+}
+
 template <typename T, typename P>
-static mlir::SelectOp clampHelper(Operation *op, ValueRange args,
-                                  mlir::ConstantOp min, mlir::ConstantOp max,
-                                  P pred, PatternRewriter &rewriter) {
-  Location loc = op->getLoc();
-  auto smallerThanMin = rewriter.create<T>(loc, pred, args[0], min);
+static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min,
+                                  mlir::ConstantOp max, P pred,
+                                  OpBuilder &rewriter) {
+  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
   auto minOrArg =
-      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, args[0]);
-  auto largerThanMax = rewriter.create<T>(loc, pred, max, args[0]);
+      rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
+  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
   return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
 }
 
@@ -211,7 +238,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                                  op->getAttr("min_fp"));
     auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                  op->getAttr("max_fp"));
-    return clampHelper<mlir::CmpFOp>(op, args, min, max, CmpFPredicate::OLT,
+    return clampHelper<mlir::CmpFOp>(loc, args[0], min, max, CmpFPredicate::OLT,
                                      rewriter);
   }
 
@@ -220,7 +247,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                                     rewriter);
     auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                     rewriter);
-    return clampHelper<mlir::CmpIOp>(op, args, min, max, CmpIPredicate::slt,
+    return clampHelper<mlir::CmpIOp>(loc, args[0], min, max, CmpIPredicate::slt,
                                      rewriter);
   }
 
@@ -230,7 +257,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
         rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
     auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                op->getAttr("max_fp"));
-    return clampHelper<mlir::CmpFOp>(op, args, zero, n, CmpFPredicate::OLT,
+    return clampHelper<mlir::CmpFOp>(loc, args[0], zero, n, CmpFPredicate::OLT,
                                      rewriter);
   }
 
@@ -239,7 +266,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
         rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
     auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                   rewriter);
-    return clampHelper<mlir::CmpIOp>(op, args, zero, n, CmpIPredicate::slt,
+    return clampHelper<mlir::CmpIOp>(loc, args[0], zero, n, CmpIPredicate::slt,
                                      rewriter);
   }
 
@@ -290,21 +317,9 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
   indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
 
   // Input indexing maps may be broadcasted.
-  for (Type types : operation->getOperandTypes()) {
-    auto shape = types.cast<ShapedType>().getShape();
-    SmallVector<AffineExpr, 4> dimExprs;
-    dimExprs.reserve(nloops);
-    for (unsigned i = 0; i < nloops; ++i) {
-      // If the dimension is one we can broadcast the input with a constant
-      // affine expression.
-      if (shape[i] == 1)
-        dimExprs.push_back(rewriter.getAffineConstantExpr(0));
-      else
-        dimExprs.push_back(rewriter.getAffineDimExpr(i));
-    }
-    indexingMaps.push_back(AffineMap::get(/*dimCount=*/nloops,
-                                          /*symbolCount=*/0, dimExprs,
-                                          rewriter.getContext()));
+  for (Type type : operation->getOperandTypes()) {
+    indexingMaps.push_back(
+        createAffineMapForType(type.cast<ShapedType>(), rewriter));
   }
 
   indexingMaps.append(operation->getNumResults(),
@@ -632,6 +647,142 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
   }
 };
 
+class RescaleOpConverter : public OpRewritePattern<tosa::RescaleOp> {
+public:
+  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::RescaleOp op,
+                                PatternRewriter &rewriter) const final {
+    auto loc = op.getLoc();
+    auto input = op.input();
+    auto inputTy = op.input().getType().cast<ShapedType>();
+    auto outputTy = op.output().getType().cast<ShapedType>();
+    unsigned rank = inputTy.getRank();
+
+    if (!outputTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "tosa to linalg conversion expects statically shaped tensors");
+
+    // The shift and multiplier values.
+    SmallVector<int32_t> multiplierValues;
+    getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues);
+
+    SmallVector<int8_t> shiftValues;
+    getValuesFromIntArrayAttribute(op.shift(), shiftValues);
+
+    // Double rounding is only needed if a shift value is greater than 31;
+    // check whether this is ever the case.
+    bool doubleRound =
+        op.double_round() &&
+        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+
+    // We need to broadcast along the last dimension, so make all dims 1.
+    SmallVector<int64_t> multiplierShape;
+    multiplierShape.resize(rank, 1);
+
+    SmallVector<int64_t> shiftShape;
+    shiftShape.resize(rank, 1);
+
+    // Set the channel dimension to match the number of shift/broadcast
+    // channels.
+    if (!multiplierShape.empty())
+      multiplierShape.back() = multiplierValues.size();
+    if (!shiftShape.empty())
+      shiftShape.back() = shiftValues.size();
+
+    // Create the tensor types.
+    auto multiplierType =
+        RankedTensorType::get(multiplierShape, rewriter.getI32Type());
+    auto shiftType =
+        RankedTensorType::get(shiftShape, rewriter.getIntegerType(8));
+
+    auto multiplierConst = rewriter.create<ConstantOp>(
+        loc, DenseIntElementsAttr::get(multiplierType, multiplierValues));
+
+    auto shiftConst = rewriter.create<ConstantOp>(
+        loc, DenseIntElementsAttr::get(shiftType, shiftValues));
+
+    // Construct the indexing maps needed for linalg.generic ops.
+    SmallVector<Type> bodyArgTypes = {getElementTypeOrSelf(inputTy),
+                                      rewriter.getI32Type(),
+                                      rewriter.getI32Type()};
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, ArrayRef<Value>({}), outputTy.getShape(),
+        outputTy.getElementType());
+
+    SmallVector<AffineMap, 4> indexingMaps;
+
+    // Indexing map for input values.
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
+
+    // Shift and multiplier will need to broadcast across their non channel
+    // values.
+    indexingMaps.push_back(createAffineMapForType(multiplierType, rewriter));
+    indexingMaps.push_back(createAffineMapForType(shiftType, rewriter));
+
+    // Indexing maps for output values.
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
+
+    auto linalgOp = rewriter.create<linalg::GenericOp>(
+        loc, outputTy, ValueRange{input, multiplierConst, shiftConst},
+        ValueRange{initTensor}, indexingMaps, getNParallelLoopsAttrs(rank),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc,
+            ValueRange blockArgs) {
+          // For now we do all of our math in 64-bit. This is not optimal but
+          // should be correct for now, consider computing correct bit depth
+          // later.
+          auto inputZp = createConstFromIntAttribute<int32_t>(
+              op, "input_zp", nestedBuilder.getI32Type(), nestedBuilder);
+          auto outputZp = createConstFromIntAttribute<int32_t>(
+              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
+
+          Value value = blockArgs[0];
+          Value multiplier = blockArgs[1];
+          Value shift = blockArgs[2];
+
+          if (value.getType().getIntOrFloatBitWidth() < 32) {
+            value = nestedBuilder.create<SignExtendIOp>(
+                nestedLoc, nestedBuilder.getI32Type(), value);
+          }
+
+          value = nestedBuilder.create<SubIOp>(nestedLoc, value, inputZp);
+
+          value = nestedBuilder.create<tosa::ApplyScaleOp>(
+              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
+              nestedBuilder.getBoolAttr(doubleRound));
+
+          // Move to the new zero-point.
+          value = nestedBuilder.create<AddIOp>(nestedLoc, value, outputZp);
+
+          // Saturate to the output size.
+          IntegerType outIntType =
+              blockArgs.back().getType().cast<IntegerType>();
+          unsigned outBitWidth = outIntType.getWidth();
+          auto intMin = nestedBuilder.create<ConstantOp>(
+              loc, nestedBuilder.getIntegerAttr(
+                       nestedBuilder.getI32Type(),
+                       APInt::getSignedMinValue(outBitWidth).getSExtValue()));
+          auto intMax = nestedBuilder.create<ConstantOp>(
+              loc, nestedBuilder.getIntegerAttr(
+                       nestedBuilder.getI32Type(),
+                       APInt::getSignedMaxValue(outBitWidth).getSExtValue()));
+
+          value = clampHelper<mlir::CmpIOp>(nestedLoc, value, intMin, intMax,
+                                            CmpIPredicate::slt, nestedBuilder);
+
+          if (outIntType.getWidth() < 32) {
+            value =
+                nestedBuilder.create<TruncateIOp>(nestedLoc, outIntType, value);
+          }
+
+          nestedBuilder.create<linalg::YieldOp>(loc, value);
+        });
+
+    rewriter.replaceOp(op, linalgOp->getResults());
+    return success();
+  }
+};
+
 // At the codegen level any identity operations should be removed. Any cases
 // where identity is load-bearing (e.g. cross device computation) should be
 // handled before lowering to codegen.
@@ -729,5 +880,5 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
       ReduceConverter<tosa::ReduceProdOp>, ConcatOpConversion,
-      ReshapeOpConverter, TransposeConverter>(context);
+      ReshapeOpConverter, TransposeConverter, RescaleOpConverter>(context);
 }

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index a1bd694f67af..e0f1369b43a5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -42,6 +42,13 @@ struct TosaToLinalgOnTensors
     target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
                            StandardOpsDialect>();
     target.addIllegalDialect<tosa::TosaDialect>();
+
+    // Not every TOSA op can be legalized to linalg.
+    target.addLegalOp<tosa::ApplyScaleOp>();
+    target.addLegalOp<tosa::IfOp>();
+    target.addLegalOp<tosa::ConstOp>();
+    target.addLegalOp<tosa::WhileOp>();
+
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
     FuncOp func = getFunction();

diff  --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
index 6e5411dd5ecb..95f5c51ff1f0 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -46,7 +46,107 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
         sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
         ValueRange({}), sliceOp.start(), sliceOp.size(),
         rewriter.getI64ArrayAttr(strides));
+    return success();
+  }
+};
+
+// This converts the TOSA ApplyScale operator to a set of StandardOps ops,
+// using 64-bit operations to perform the necessary multiply, bias, and shift.
+// Multiple types are used to use minimal bit width operations.
+class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
+public:
+  using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Value value32 = op.value();
+    Value multiplier32 = op.multiplier();
+    Value shift8 = op.shift();
+    bool doubleRound = op.double_round();
+
+    Value one8 = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1));
+    Value one32 = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1));
+    Value one64 = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
+
+    Value shiftSubOne8 = rewriter.create<SubIOp>(loc, shift8, one8);
+
+    // The rounding value semantics below equate to the following code:
+    //    int64_t round = 1 << (shift - 1);
+    //    if (double_round) {
+    //      if (shift > 31 && value >= 0) round += 1<<30;
+    //      if (shift > 31 && value < 0) round -= 1<<30;
+    //    }
+    //
+    // Note that minimal bitwidth operators are used throughout the block.
+
+    Value shift32 = rewriter.create<mlir::SignExtendIOp>(
+        loc, rewriter.getI32Type(), shift8);
+
+    Value round64 = rewriter.create<mlir::ShiftLeftOp>(
+        loc, one64,
+        rewriter.create<SignExtendIOp>(loc, rewriter.getI64Type(),
+                                       shiftSubOne8));
+
+    // Double rounding performs a rounding adjustment before the shift.
+    if (doubleRound) {
+      Value zero32 = rewriter.create<ConstantOp>(
+          loc, rewriter.getZeroAttr(rewriter.getI32Type()));
+      Value thirty32 = rewriter.create<ConstantOp>(
+          loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30));
 
+      Value shiftThirty32 =
+          rewriter.create<mlir::ShiftLeftOp>(loc, one32, thirty32);
+      Value shiftThirty64 = rewriter.create<mlir::SignExtendIOp>(
+          loc, rewriter.getI64Type(), shiftThirty32);
+
+      // The round value needs to be added or subtracted depending on the sign of the input.
+      Value roundAdd64 =
+          rewriter.create<mlir::AddIOp>(loc, round64, shiftThirty64);
+      Value roundSub64 =
+          rewriter.create<mlir::SubIOp>(loc, round64, shiftThirty64);
+
+      Value valueGreaterThanZero = rewriter.create<mlir::CmpIOp>(
+          loc, CmpIPredicate::sge, value32, zero32);
+
+      Value doubleRound64 = rewriter.create<mlir::SelectOp>(
+          loc, valueGreaterThanZero, roundAdd64, roundSub64);
+
+      // We only perform double rounding if the shift value is greater than 31.
+      Value thirtyTwo32 = rewriter.create<ConstantOp>(
+          loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 32));
+      Value shiftGreaterThanThirtyTwo = rewriter.create<mlir::CmpIOp>(
+          loc, CmpIPredicate::sge, shift32, thirtyTwo32);
+      round64 = rewriter.create<mlir::SelectOp>(loc, shiftGreaterThanThirtyTwo,
+                                                doubleRound64, round64);
+    }
+
+    // The computation below equates to the following pseudocode:
+    //    int64_t result = (int64_t)value * multiplier + round;
+    //    result = result >> shift;
+    //
+    // Note that the multiply and shift need to be performed in i64 to preserve bits.
+
+    Value value64 =
+        rewriter.create<SignExtendIOp>(loc, rewriter.getI64Type(), value32);
+    Value multiplier64 = rewriter.create<SignExtendIOp>(
+        loc, rewriter.getI64Type(), multiplier32);
+    Value shift64 =
+        rewriter.create<SignExtendIOp>(loc, rewriter.getI64Type(), shift8);
+
+    // Multiply as a pair of i64 values to guarantee the end value fits.
+    Value result64 = rewriter.create<MulIOp>(loc, value64, multiplier64);
+    result64 = rewriter.create<AddIOp>(loc, result64, round64);
+    result64 =
+        rewriter.create<mlir::SignedShiftRightOp>(loc, result64, shift64);
+
+    Value result32 = rewriter.create<mlir::TruncateIOp>(
+        loc, rewriter.getI32Type(), result64);
+
+    rewriter.replaceOp(op, result32);
     return success();
   }
 };
@@ -55,5 +155,11 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
 
 void mlir::tosa::populateTosaToStandardConversionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
-  patterns->insert<ConstOpConverter, SliceOpConverter>(context);
+  patterns->insert<ApplyScaleOpConverter, ConstOpConverter, SliceOpConverter>(
+      context);
+}
+
+void mlir::tosa::populateTosaRescaleToStandardConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<ApplyScaleOpConverter>(context);
 }

diff  --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
index 78a0e65da81b..14c800e2f70d 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
@@ -33,6 +33,7 @@ struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
     ConversionTarget target(getContext());
     target.addIllegalOp<tosa::ConstOp>();
     target.addIllegalOp<tosa::SliceOp>();
+    target.addIllegalOp<tosa::ApplyScaleOp>();
     target.addLegalDialect<StandardOpsDialect>();
 
     auto *op = getOperation();

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 9b1f6054ee06..1714f140dbfc 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -473,3 +473,54 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
   %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>)  -> (tensor<5x2xf32>)
   return
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)>
+
+// CHECK-LABEL: @rescale
+func @rescale(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) {
+  // CHECK: [[C0:%.+]] = constant dense<19689>
+  // CHECK: [[C1:%.+]] = constant dense<15>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[C0]], [[C1]] : tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) outs([[INIT]] : tensor<1xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: i8, [[MULTIPLIER:%.+]]: i32, [[SHIFT:%.+]]: i8, [[UNUSED:%.+]]: i8):
+  // CHECK: [[C243:%.+]] = constant 243
+  // CHECK: [[C252:%.+]] = constant 252
+
+  // CHECK-DAG: [[IN32:%.+]] = sexti [[IN]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C243]]
+  // CHECK-DAG: [[SCALED:%.+]] = "tosa.apply_scale"([[IN_ZEROED]], [[MULTIPLIER]], [[SHIFT]]) {double_round = false}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C252]]
+  // CHECK-DAG: [[CMIN:%.+]] = constant -128
+  // CHECK-DAG: [[CMAX:%.+]] = constant 127
+  // CHECK-DAG: [[MINLT:%.+]] = cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
+  // CHECK-DAG: [[MAXLT:%.+]] = cmpi slt, [[CMAX]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]]
+  // CHECK-DAG: linalg.yield [[TRUNC]]
+  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<1xi8>)  -> (tensor<1xi8>)
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<1xi8>
+}
+
+// CHECK-LABEL: @rescaleDoubleRound
+func @rescaleDoubleRound(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) {
+  // CHECK: linalg.generic
+  // CHECK: "tosa.apply_scale"
+  // CHECK-SAME:  {double_round = true}
+  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [33 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>)  -> (tensor<1xi8>)
+  return %0 : tensor<1xi8>
+}
+
+// CHECK-LABEL: @rescaleUnnecessaryDoubleRound
+func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) {
+  // CHECK: linalg.generic
+  // CHECK: "tosa.apply_scale"
+  // CHECK-SAME:  {double_round = false}
+  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>)  -> (tensor<1xi8>)
+  return %0 : tensor<1xi8>
+}

diff  --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
index 94925aec15c7..2c80c31cf297 100644
--- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -9,10 +9,46 @@ func @const_test() -> (tensor<i32>) {
   return %0 : tensor<i32>
 }
 
-// ----
+// -----
 
 func @slice(%arg0: tensor<6xf32>) ->() {
   // CHECK: [[SLICE:%.+]] = subtensor %arg0[2] [1] [1]
   %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>)  -> (tensor<1xf32>)
   return
 }
+
+// -----
+
+func @apply_scale_test(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
+  // CHECK: [[C1_8:%.+]] = constant 1 : i8
+  // CHECK: [[C1_32:%.+]] = constant 1 : i32
+  // CHECK: [[C1_64:%.+]] = constant 1 : i64
+  // CHECK: [[SHIFT_MINUS_ONE_8:%.+]] = subi %arg2, [[C1_8]]
+
+  // CHECK: [[SHIFT_32:%.+]] = sexti %arg2 : i8 to i32
+  // CHECK: [[SHIFT_MINUS_ONE_64:%.+]] = sexti [[SHIFT_MINUS_ONE_8]] : i8 to i64
+  // CHECK: [[SHIFTED_64:%.+]] = shift_left [[C1_64]], [[SHIFT_MINUS_ONE_64]]
+
+  // CHECK: [[C0_32:%.+]] = constant 0 : i32
+  // CHECK: [[C30_32:%.+]] = constant 30 : i32
+  // CHECK: [[SECOND_BIAS:%.+]] = shift_left [[C1_32]], [[C30_32]]
+  // CHECK: [[SECOND_BIAS_64:%.+]] = sexti [[SECOND_BIAS]] : i32 to i64
+  // CHECK: [[POSITIVE_ROUND:%.+]] = addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+  // CHECK: [[NEGATIVE_ROUND:%.+]] = subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+  // CHECK: [[VALUE_NEGATIVE:%.+]] = cmpi sge, %arg0, [[C0_32]] : i32
+  // CHECK: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64
+  // CHECK: [[C32_32:%.+]] = constant 32 : i32
+  // CHECK: [[IS_32BIT_SHIFT:%.+]] = cmpi sge, [[SHIFT_32]], [[C32_32]]
+  // CHECK: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
+
+  // CHECK: [[VAL_64:%.+]] = sexti %arg0 : i32 to i64
+  // CHECK: [[MULTIPLY_64:%.+]] = sexti %arg1 : i32 to i64
+  // CHECK: [[SHIFT_64:%.+]] = sexti %arg2 : i8 to i64
+  // CHECK: [[SCALED:%.+]] = muli [[VAL_64]], [[MULTIPLY_64]]
+  // CHECK: [[BIASED:%.+]] = addi [[SCALED]], [[ROUND]]
+  // CHECK: [[DOWNSHIFTED:%.+]] = shift_right_signed [[BIASED]], [[SHIFT_64]]
+  // CHECK: [[TRUNCATED:%.+]] = trunci [[DOWNSHIFTED]]
+
+  %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32
+  return %0 : i32
+}


        


More information about the Mlir-commits mailing list