[Mlir-commits] [mlir] be01b09 - [mlir][tosa] Remove constant-0 dim expr values from TOSA lowerings

Rob Suderman llvmlistbot at llvm.org
Thu Apr 29 15:23:17 PDT 2021


Author: Rob Suderman
Date: 2021-04-29T15:06:03-07:00
New Revision: be01b091afd820c5784ba960241ea6140529b654

URL: https://github.com/llvm/llvm-project/commit/be01b091afd820c5784ba960241ea6140529b654
DIFF: https://github.com/llvm/llvm-project/commit/be01b091afd820c5784ba960241ea6140529b654.diff

LOG: [mlir][tosa] Remove constant-0 dim expr values from TOSA lowerings

Constant-0 dim expr values should be avoided for linalg as they can prevent
fusion. This change also adds support for rank-0 reshapes.

Differential Revision: https://reviews.llvm.org/D101418

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b31656467c485..8107d97cb61f2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -48,26 +48,6 @@ static void getValuesFromIntArrayAttribute(ArrayAttr attr,
   }
 }
 
-// Generates an affine map for parallel operations on a given type. This
-// performs implicit broadcasting across any dimension of size-1.
-static AffineMap createAffineMapForType(ShapedType type,
-                                        PatternRewriter &rewriter) {
-  unsigned rank = type.getRank();
-  auto shape = type.getShape();
-  SmallVector<AffineExpr, 4> dimExprs;
-  dimExprs.reserve(rank);
-  for (unsigned i = 0; i < rank; ++i) {
-    // If the dimension is one we can broadcast the input with a constant
-    // affine expression.
-    if (shape[i] == 1)
-      dimExprs.push_back(rewriter.getAffineConstantExpr(0));
-    else
-      dimExprs.push_back(rewriter.getAffineDimExpr(i));
-  }
-  return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
-                        rewriter.getContext());
-}
-
 template <typename T, typename P>
 static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min,
                                   mlir::ConstantOp max, P pred,
@@ -464,11 +444,14 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
                                  PatternRewriter &rewriter) {
   auto loc = operation->getLoc();
   auto results = operation->getResults();
-  auto t0 = operation->getOperand(0).getType().template dyn_cast<ShapedType>();
-  if (!t0)
+  auto resultTy = operation->getOperand(0).getType().dyn_cast<ShapedType>();
+
+  if (!resultTy)
     return rewriter.notifyMatchFailure(operation,
                                        "All results must be a shaped type");
 
+  unsigned rank = resultTy.getRank();
+
   assert(operation->getNumResults() == 1 &&
          "All TOSA elementwise ops should only return a single result.");
 
@@ -496,23 +479,42 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
   auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
       initTensors, [](Value v) { return getElementTypeOrSelf(v); }));
 
-  unsigned nloops = t0.getRank();
+  SmallVector<Value, 2> operands;
   SmallVector<AffineMap, 2> indexingMaps;
   indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
 
   // Input indexing maps may be broadcasted.
-  for (Type type : operation->getOperandTypes()) {
-    indexingMaps.push_back(
-        createAffineMapForType(type.cast<ShapedType>(), rewriter));
+  for (Value operand : operation->getOperands()) {
+    ShapedType type = operand.getType().cast<ShapedType>();
+    SmallVector<int64_t, 5> newShape;
+    SmallVector<AffineExpr, 4> affineExprs;
+    newShape.reserve(type.getRank());
+    for (auto it : llvm::enumerate(type.getShape())) {
+      if (it.value() != 1) {
+        newShape.push_back(it.value());
+        affineExprs.push_back(
+            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
+      }
+    }
+
+    if (newShape.size() != rank) {
+      operand = rewriter.create<tosa::ReshapeOp>(
+          loc, RankedTensorType::get(newShape, type.getElementType()), operand);
+    }
+
+    operands.push_back(operand);
+    indexingMaps.push_back(AffineMap::get(
+        /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs,
+        rewriter.getContext()));
   }
 
   indexingMaps.append(operation->getNumResults(),
-                      rewriter.getMultiDimIdentityMap(nloops));
+                      rewriter.getMultiDimIdentityMap(rank));
 
   bool didEncounterError = false;
   auto linalgOp = rewriter.create<linalg::GenericOp>(
-      loc, opResultTypes, operation->getOperands(), initTensors, indexingMaps,
-      getNParallelLoopsAttrs(nloops),
+      loc, opResultTypes, operands, initTensors, indexingMaps,
+      getNParallelLoopsAttrs(rank),
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
         Value opResult = createLinalgBodyCalculationForElementwiseOp(
             operation, blockArgs.take_front(operation->getNumOperands()),
@@ -650,12 +652,20 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   auto elementTy = resultTy.getElementType();
   Value input = op->getOperand(0);
 
+  llvm::SmallVector<int64_t> reduceShape;
+  for (unsigned i = 0; i < inputTy.getRank(); i++) {
+    if (axis != i)
+      reduceShape.push_back(inputTy.getDimSize(i));
+  }
+
+  Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());
+
   // First fill the output buffer with the init value.
-  auto initTensor = rewriter
-                        .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
-                                                      resultTy.getShape(),
-                                                      resultTy.getElementType())
-                        .result();
+  auto initTensor =
+      rewriter
+          .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}), reduceShape,
+                                        resultTy.getElementType())
+          .result();
 
   auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
   if (!fillValueAttr)
@@ -676,14 +686,12 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                       : getParallelIteratorTypeName());
     if (axis != i)
       dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
-    else
-      dstExprs.push_back(rewriter.getAffineConstantExpr(0));
   }
 
   bool didEncounterError = false;
   auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
   auto linalgOp = rewriter.create<linalg::GenericOp>(
-      loc, resultTy, input, filledTensor, maps, iteratorTypes,
+      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
         auto result = createLinalgBodyCalculationForReduceOp(
             op, blockArgs, elementTy, rewriter);
@@ -696,7 +704,8 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   if (!didEncounterError)
     return failure();
 
-  rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
+  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
+                                               linalgOp.getResults());
   return success();
 }
 
@@ -971,9 +980,12 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
       }
       currDstDim++;
     }
+
+    // Check if any remaining dimensions exist. If either shape is rank-0, we
+    // only require the direct (single-step) lowering.
     if (currSrcDim != expandedShape.size() ||
         currDstDim != collapsedShape.size())
-      isCollapsingSource = false;
+      isCollapsingSource = collapsedShape.empty() || expandedShape.empty();
 
     // Otherwise, we need to first reduce all source dimensions into one and
     // then expand to the destination dimensions.
@@ -1084,56 +1096,65 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
         op.double_round() &&
         llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
 
-    // We need to broadcast along the last dimension, so make all dims 1.
-    SmallVector<int64_t> multiplierShape;
-    multiplierShape.resize(rank, 1);
-
-    SmallVector<int64_t> shiftShape;
-    shiftShape.resize(rank, 1);
-
-    // Set the channel dimension to match the number of shift/broadcast
-    // channels.
-    if (!multiplierShape.empty())
-      multiplierShape.back() = multiplierValues.size();
-    if (!shiftShape.empty())
-      shiftShape.back() = shiftValues.size();
-
-    // Create the tensor types.
-    auto multiplierType =
-        RankedTensorType::get(multiplierShape, rewriter.getI32Type());
-    auto shiftType =
-        RankedTensorType::get(shiftShape, rewriter.getIntegerType(8));
+    SmallVector<AffineMap> indexingMaps = {
+        rewriter.getMultiDimIdentityMap(rank)};
+    SmallVector<Value, 4> genericInputs = {input};
+
+    // If we are rescaling per-channel then we need to store the multiplier
+    // values in a buffer.
+    Value multiplierConstant;
+    int64_t multiplierArg = 0;
+    if (multiplierValues.size() == 1) {
+      multiplierConstant = rewriter.create<ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+    } else {
+      SmallVector<AffineExpr, 2> multiplierExprs{
+          rewriter.getAffineDimExpr(rank - 1)};
+      auto multiplierType =
+          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
+                                rewriter.getI32Type());
+      genericInputs.push_back(rewriter.create<ConstantOp>(
+          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
+
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+
+      multiplierArg = indexingMaps.size() - 1;
+    }
 
-    auto multiplierConst = rewriter.create<ConstantOp>(
-        loc, DenseIntElementsAttr::get(multiplierType, multiplierValues));
+    // If we are rescaling per-channel then we need to store the shift
+    // values in a buffer.
+    Value shiftConstant;
+    int64_t shiftArg = 0;
+    if (shiftValues.size() == 1) {
+      shiftConstant = rewriter.create<ConstantOp>(
+          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+    } else {
+      SmallVector<AffineExpr, 2> shiftExprs = {
+          rewriter.getAffineDimExpr(rank - 1)};
+      auto shiftType =
+          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
+                                rewriter.getIntegerType(8));
+      genericInputs.push_back(rewriter.create<ConstantOp>(
+          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+      shiftArg = indexingMaps.size() - 1;
+    }
 
-    auto shiftConst = rewriter.create<ConstantOp>(
-        loc, DenseIntElementsAttr::get(shiftType, shiftValues));
+    // Indexing maps for output values.
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
 
     // Construct the indexing maps needed for linalg.generic ops.
-    SmallVector<Type> bodyArgTypes = {getElementTypeOrSelf(inputTy),
-                                      rewriter.getI32Type(),
-                                      rewriter.getI32Type()};
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
         loc, ArrayRef<Value>({}), outputTy.getShape(),
         outputTy.getElementType());
 
-    SmallVector<AffineMap, 4> indexingMaps;
-
-    // Indexing map for input values.
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
-
-    // Shift and multiplier will need to broadcast across their non channel
-    // values.
-    indexingMaps.push_back(createAffineMapForType(multiplierType, rewriter));
-    indexingMaps.push_back(createAffineMapForType(shiftType, rewriter));
-
-    // Indexing maps for output values.
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
-
     auto linalgOp = rewriter.create<linalg::GenericOp>(
-        loc, outputTy, ValueRange{input, multiplierConst, shiftConst},
-        ValueRange{initTensor}, indexingMaps, getNParallelLoopsAttrs(rank),
+        loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
+        getNParallelLoopsAttrs(rank),
         [&](OpBuilder &nestedBuilder, Location nestedLoc,
             ValueRange blockArgs) {
           // For now we do all of our math in 64-bit. This is not optimal but
@@ -1145,8 +1166,9 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
               op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
 
           Value value = blockArgs[0];
-          Value multiplier = blockArgs[1];
-          Value shift = blockArgs[2];
+          Value multiplier = multiplierConstant ? multiplierConstant
+                                                : blockArgs[multiplierArg];
+          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
 
           if (value.getType().getIntOrFloatBitWidth() < 32) {
             value = nestedBuilder.create<SignExtendIOp>(
@@ -1608,17 +1630,6 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
     SmallVector<int64_t> multiples;
     getValuesFromIntArrayAttribute(op.multiples(), multiples);
 
-    llvm::SmallVector<int64_t, 4> reshapeShape;
-    reshapeShape.reserve(rank * 2);
-    for (int i = 0; i < rank; i++) {
-      reshapeShape.push_back(1);
-      reshapeShape.push_back(inputShape[i]);
-    }
-
-    ShapedType reshapeTy = RankedTensorType::get(reshapeShape, elementTy);
-    Value reshape = rewriter.create<tosa::ReshapeOp>(
-        loc, reshapeTy, input, rewriter.getI64ArrayAttr(reshapeTy.getShape()));
-
     // Broadcast the newly added dimensions to their appropriate multiple.
     SmallVector<int64_t, 2> genericShape;
     for (int i = 0; i < rank; i++) {
@@ -1629,12 +1640,21 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
     auto initTensor = rewriter.create<linalg::InitTensorOp>(
         op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);
 
+    // We need to map the input shape to the non-broadcasted dimensions.
+    SmallVector<AffineExpr, 4> dimExprs;
+    dimExprs.reserve(rank);
+    for (unsigned i = 0; i < rank; ++i)
+      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
+
+    auto readAffineMap =
+        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
+                       rewriter.getContext());
+
     SmallVector<AffineMap, 2> affineMaps = {
-        createAffineMapForType(reshapeTy, rewriter),
-        rewriter.getMultiDimIdentityMap(genericShape.size())};
+        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
 
     auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, RankedTensorType::get(genericShape, elementTy), reshape,
+        loc, RankedTensorType::get(genericShape, elementTy), input,
         ValueRange{initTensor}, affineMaps,
         getNParallelLoopsAttrs(genericShape.size()),
         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 57d8c86fa25bc..775ac6d06ed4e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -55,13 +55,14 @@ func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> ()>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
 
 // CHECK-LABEL: @test_broadcast
 func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %arg1 : tensor<f32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
   // CHECK:   [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
   // CHECK:   linalg.yield [[ELEMENT]] : f32
@@ -72,14 +73,16 @@ func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (0, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
 
 // CHECK-LABEL: @test_multibroadcast
 func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<1x3xf32>, tensor<2x1xf32>) outs([[INIT]] : tensor<2x3xf32>) {
+  // CHECK: [[RESHAPE1:%.+]] = linalg.tensor_reshape %arg0 [#map0]
+  // CHECK: [[RESHAPE2:%.+]] = linalg.tensor_reshape %arg1 [#map0]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) {
   // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
   // CHECK:   [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
   // CHECK:   linalg.yield [[ELEMENT]] : f32
@@ -472,28 +475,30 @@ func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
 
 // CHECK-LABEL: @reduce_float
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
 func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 4]
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
   // CHECK: [[CST0:%.+]] = constant 0.0
   // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<1x4xf32>)
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<4xf32>)
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   [[RES:%.+]] = addf %arg1, %arg2 : f32
   // CHECK:   linalg.yield [[RES]] : f32
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#map0] : tensor<4xf32> into tensor<1x4xf32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
 
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 1]
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
   // CHECK: [[CST0:%.+]] = constant 0.0
   // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5x1xf32>)
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5xf32>)
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   [[RES:%.+]] = addf %arg1, %arg2 : f32
   // CHECK:   linalg.yield [[RES]] : f32
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#map0] : tensor<5xf32> into tensor<5x1xf32>
   %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32>
 
   // CHECK: constant 1.0
@@ -521,28 +526,30 @@ func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
 
 // CHECK-LABEL: @reduce_int
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi32>
 func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 4]
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
   // CHECK: [[CST0:%.+]] = constant 0
   // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<1x4xi32>)
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<4xi32>)
   // CHECK: ^bb0(%arg1: i32, %arg2: i32)
   // CHECK:   [[RES:%.+]] = addi %arg1, %arg2 : i32
   // CHECK:   linalg.yield [[RES]] : i32
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#map0] : tensor<4xi32> into tensor<1x4xi32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
 
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 1]
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
   // CHECK: [[CST0:%.+]] = constant 0
   // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5x1xi32>)
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5xi32>)
   // CHECK: ^bb0(%arg1: i32, %arg2: i32)
   // CHECK:   [[RES:%.+]] = addi %arg1, %arg2 : i32
   // CHECK:   linalg.yield [[RES]] : i32
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#map0] : tensor<5xi32> into tensor<5x1xi32>
   %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32>
 
   // CHECK: constant 1
@@ -570,18 +577,19 @@ func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
 
 // CHECK-LABEL: @reduce_bool
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi1>
 func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 4]
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
   // CHECK: [[CST0:%.+]] = constant true
   // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi1>) outs([[FILL]] : tensor<1x4xi1>)
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi1>) outs([[FILL]] : tensor<4xi1>)
   // CHECK: ^bb0(%arg1: i1, %arg2: i1)
   // CHECK:   [[RES:%.+]] = and %arg1, %arg2 : i1
   // CHECK:   linalg.yield [[RES]] : i1
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#map0] : tensor<4xi1> into tensor<1x4xi1>
   %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1>
 
   // CHECK: constant false
@@ -636,14 +644,45 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)>
 
 // CHECK-LABEL: @rescale
-func @rescale(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) {
-  // CHECK: [[C0:%.+]] = constant dense<19689>
-  // CHECK: [[C1:%.+]] = constant dense<15>
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[C0]], [[C1]] : tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) outs([[INIT]] : tensor<1xi8>)
+func @rescale(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
+  // CHECK: [[C0:%.+]] = constant 19689
+  // CHECK: [[C1:%.+]] = constant 15
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
+  // CHECK: [[C243:%.+]] = constant 243
+  // CHECK: [[C252:%.+]] = constant 252
+
+  // CHECK-DAG: [[IN32:%.+]] = sexti [[IN]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C243]]
+  // CHECK-DAG: [[SCALED:%.+]] = "tosa.apply_scale"([[IN_ZEROED]], [[C0]], [[C1]]) {double_round = false}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C252]]
+  // CHECK-DAG: [[CMIN:%.+]] = constant -128
+  // CHECK-DAG: [[CMAX:%.+]] = constant 127
+  // CHECK-DAG: [[MINLT:%.+]] = cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
+  // CHECK-DAG: [[MAXLT:%.+]] = cmpi slt, [[CMAX]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]]
+  // CHECK-DAG: linalg.yield [[TRUNC]]
+  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>)  -> (tensor<2xi8>)
+
+  // CHECK: return [[GENERIC]]
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_per_channel
+func @rescale_per_channel(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
+  // CHECK: [[MULTIPLIERS:%.+]] = constant dense<[42, 43]>
+  // CHECK: [[SHIFTS:%.+]] = constant dense<[14, 15]>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[MULTIPLIERS]], [[SHIFTS]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[MULTIPLIER:%.+]]: i32, [[SHIFT:%.+]]: i8, [[UNUSED:%.+]]: i8):
   // CHECK: [[C243:%.+]] = constant 243
   // CHECK: [[C252:%.+]] = constant 252
@@ -660,28 +699,30 @@ func @rescale(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) {
   // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<1xi8>)  -> (tensor<1xi8>)
+  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [42 : i32, 43 : i32], shift = [14 : i32, 15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>)  -> (tensor<2xi8>)
 
   // CHECK: return [[GENERIC]]
-  return %0 : tensor<1xi8>
+  return %0 : tensor<2xi8>
 }
 
+// -----
+
 // CHECK-LABEL: @rescaleDoubleRound
-func @rescaleDoubleRound(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) {
+func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
   // CHECK: linalg.generic
   // CHECK: "tosa.apply_scale"
   // CHECK-SAME:  {double_round = true}
-  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [33 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>)  -> (tensor<1xi8>)
-  return %0 : tensor<1xi8>
+  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [33 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>)  -> (tensor<2xi8>)
+  return %0 : tensor<2xi8>
 }
 
 // CHECK-LABEL: @rescaleUnnecessaryDoubleRound
-func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) {
+func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
   // CHECK: linalg.generic
   // CHECK: "tosa.apply_scale"
   // CHECK-SAME:  {double_round = false}
-  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>)  -> (tensor<1xi8>)
-  return %0 : tensor<1xi8>
+  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>)  -> (tensor<2xi8>)
+  return %0 : tensor<2xi8>
 }
 
 // -----
@@ -708,32 +749,29 @@ func @reverse(%arg0: tensor<5x4xi32>) -> () {
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
-// CHECK: #[[$MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: #[[$MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
 // CHECK: #[[$MAP5:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
 
 // CHECK-LABEL: @tile
 func @tile(%arg0 : tensor<2x3xi8>) -> () {
-  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8>
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 2, 1, 3]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>)
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>)
   // CHECK:   linalg.yield %arg1 : i8
-  // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP0]], #[[$MAP1]]]
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP2]], #[[$MAP3]]]
   %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<2x3xi8>)  -> (tensor<4x3xi8>)
 
-  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8>
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2, 2, 3]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
   // CHECK:   linalg.yield %arg1 : i8
   // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP4]], #[[$MAP5]]]
   %1 = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<2x3xi8>)  -> (tensor<2x6xi8>)
 
-  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8>
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2, 7, 3]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
   // CHECK:   linalg.yield %arg1 : i8
   // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP4]], #[[$MAP5]]]
   %2 = "tosa.tile"(%arg0) {multiples = [5, 7]} : (tensor<2x3xi8>)  -> (tensor<10x21xi8>)


        


More information about the Mlir-commits mailing list