[Mlir-commits] [mlir] 5541a05 - [mlir][tosa] Quantized tosa.avg_pool2d lowering to linalg

Rob Suderman llvmlistbot at llvm.org
Tue Aug 24 19:02:55 PDT 2021


Author: Rob Suderman
Date: 2021-08-24T18:54:23-07:00
New Revision: 5541a05d6a5a74424dd6d98cfb6d9014a5fb17ca

URL: https://github.com/llvm/llvm-project/commit/5541a05d6a5a74424dd6d98cfb6d9014a5fb17ca
DIFF: https://github.com/llvm/llvm-project/commit/5541a05d6a5a74424dd6d98cfb6d9014a5fb17ca.diff

LOG: [mlir][tosa] Quantized tosa.avg_pool2d lowering to linalg

Adds the quantized version of the average pool lowering to the linalg
dialect, along with a lit test for the transform. The lowering is not 100%
exact: the multiplier / shift should be computed in i64, but the resulting
rounding difference is negligible.
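
For context, the quantized path normalizes the pooled sum with a fixed-point
reciprocal multiply instead of a division: multiplier = ((1 << 30) + 1) / count
and shift = 30, fed into tosa.apply_scale. A minimal standalone C++ sketch of
that idea (the helper name and the round-half-up rule below are illustrative
assumptions, not the exact tosa.apply_scale semantics):

    #include <cstdint>
    #include <cstdio>

    // Approximate sum / count via a precomputed reciprocal, as the lowering
    // does: multiplier ~= 2^30 / count, then shift right by 30.
    int32_t approxDivide(int32_t sum, int32_t count) {
      uint32_t multiplier = ((1u << 30) + 1) / static_cast<uint32_t>(count);
      // Widen for the product, add half before shifting (round-half-up).
      int64_t scaled = static_cast<int64_t>(sum) * multiplier + (1ll << 29);
      return static_cast<int32_t>(scaled >> 30);
    }

    int main() {
      // Compare against rounded division for a 4x4 kernel (count = 16).
      for (int32_t sum : {-2048, -17, 0, 17, 255, 4096}) {
        int32_t exact = (sum >= 0 ? sum + 8 : sum - 8) / 16;
        printf("sum=%5d approx=%4d exact=%4d\n", sum, approxDivide(sum, 16), exact);
      }
      return 0;
    }

On these samples the approximation matches the rounded division; computing the
multiplier with i64 precision would only tighten the low-order bits, which is
the negligible difference noted above.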

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D108676

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 880da55cbd24..f7e8b0e078eb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2504,39 +2504,34 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
   }
 };
 
-template <typename SrcOp>
-class Pool2dConverter : public OpRewritePattern<SrcOp> {
+class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 public:
-  using OpRewritePattern<SrcOp>::OpRewritePattern;
+  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(SrcOp op,
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
     Value input = op.input();
     ShapedType inputTy = input.getType().cast<ShapedType>();
-    Type inElementTy = inputTy.getElementType();
 
     ShapedType resultTy = op.getType().template cast<ShapedType>();
-    Type outElementTy = inputTy.getElementType();
+    Type resultETy = inputTy.getElementType();
 
     if (!inputTy.hasStaticShape())
       return failure();
 
     // Determine what the initial value needs to be for the max pool op.
     Attribute initialAttr;
-    if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isF32())
+    if (resultETy.isF32())
       initialAttr = rewriter.getFloatAttr(
-          outElementTy,
-          APFloat::getLargest(
-              outElementTy.cast<FloatType>().getFloatSemantics(), true));
+          resultETy,
+          APFloat::getLargest(resultETy.cast<FloatType>().getFloatSemantics(),
+                              true));
 
-    if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isa<IntegerType>())
+    if (resultETy.isa<IntegerType>())
       initialAttr = rewriter.getIntegerAttr(
-          outElementTy,
-          APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));
-
-    if (isa<tosa::AvgPool2dOp>(op) && outElementTy.isa<FloatType>())
-      initialAttr = rewriter.getZeroAttr(outElementTy);
+          resultETy,
+          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
 
     if (!initialAttr)
       return rewriter.notifyMatchFailure(
@@ -2566,93 +2561,216 @@ class Pool2dConverter : public OpRewritePattern<SrcOp> {
         rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();
 
     Value fakeWindowDims =
-        rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);
+        rewriter.create<linalg::InitTensorOp>(loc, kernel, resultETy);
 
-    if (isa<tosa::MaxPool2dOp>(op)) {
-      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
-          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
-          filledInitTensor, strideAttr, dilationAttr);
-      return success();
-    }
+    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
+        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
+        filledInitTensor, strideAttr, dilationAttr);
+    return success();
+  }
+};
 
-    if (isa<tosa::AvgPool2dOp>(op) && inElementTy.isF32()) {
-      Value poolingOp = rewriter
-                            .create<linalg::PoolingNhwcSumOp>(
-                                loc, ArrayRef<Type>{resultTy},
-                                ValueRange{paddedInput, fakeWindowDims},
-                                filledInitTensor, strideAttr, dilationAttr)
-                            .getResult(0);
-      auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
-      auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
-      auto genericOp = rewriter.create<linalg::GenericOp>(
-          loc, ArrayRef<Type>({resultTy}), ValueRange{}, ValueRange{poolingOp},
-          ArrayRef<AffineMap>({affineMap}),
-          getNParallelLoopsAttrs(resultTy.getRank()),
-          [&](OpBuilder &b, Location loc, ValueRange args) {
-            auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
-            auto one = rewriter.create<ConstantIndexOp>(loc, 1);
-            auto iH = rewriter.create<ConstantIndexOp>(
-                loc, poolingOpTy.getDimSize(1) - 1);
-            auto iW = rewriter.create<ConstantIndexOp>(
-                loc, poolingOpTy.getDimSize(2) - 1);
-
-            // Compute the indices from either end.
-            auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
-            auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
-            auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
-            auto x1 = rewriter.create<SubIOp>(loc, iW, x0);
-
-            // Determines what the portion of valid input is covered by the
-            // kernel.
-            auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
-              if (pad == 0)
-                return v;
-
-              auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
-              Value dx = rewriter.create<SubIOp>(loc, x, padVal);
-
-              Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
-                                                        dx, zero);
-              Value offset =
-                  rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
-              return rewriter.create<mlir::AddIOp>(loc, v, offset)
-                  ->getResult(0);
-            };
-
-            // Compute the vertical component of coverage.
-            auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
-            auto kH1 = padFn(kH0, y0, pad[2]);
-            auto kH2 = padFn(kH1, y1, pad[3]);
-            auto kHCmp =
-                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
-            auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
-
-            // compute teh horizontal component of coverage.
-            auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
-            auto kW1 = padFn(kW0, x0, pad[4]);
-            auto kW2 = padFn(kW1, x1, pad[5]);
-            auto kWCmp =
-                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
-            auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
-
-            // Compute the total number of elements and normalize.
-            Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
-            auto countI = rewriter.create<mlir::IndexCastOp>(
-                loc, rewriter.getI32Type(), count);
+class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
+public:
+  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Value input = op.input();
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    Type inElementTy = inputTy.getElementType();
+
+    ShapedType resultTy = op.getType().template cast<ShapedType>();
+    Type resultETy = inputTy.getElementType();
+
+    Type accETy =
+        inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
+    ShapedType accTy = resultTy.clone(accETy);
+
+    if (!inputTy.hasStaticShape())
+      return failure();
+
+    // Apply padding as necessary.
+    llvm::SmallVector<int64_t> pad;
+    pad.resize(2, 0);
+    getValuesFromIntArrayAttribute(op.pad(), pad);
+    pad.resize(pad.size() + 2, 0);
+    Attribute initialAttr = rewriter.getZeroAttr(accETy);
+    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
+
+    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
+
+    SmallVector<int64_t> kernel, stride;
+    getValuesFromIntArrayAttribute(op.kernel(), kernel);
+    getValuesFromIntArrayAttribute(op.stride(), stride);
+
+    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
+    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
+
+    // Create the linalg op that performs pooling.
+    Value poolInitTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, accTy.getShape(), accETy);
+
+    Value filledInitTensor =
+        rewriter.create<linalg::FillOp>(loc, initialValue, poolInitTensor)
+            .result();
+
+    Value fakeWindowDims =
+        rewriter.create<linalg::InitTensorOp>(loc, kernel, accETy);
+
+    // Sum across the pooled region.
+    Value poolingOp = rewriter
+                          .create<linalg::PoolingNhwcSumOp>(
+                              loc, ArrayRef<Type>{accTy},
+                              ValueRange{paddedInput, fakeWindowDims},
+                              filledInitTensor, strideAttr, dilationAttr)
+                          .getResult(0);
+
+    // Normalize the summed value by the number of elements grouped in each
+    // pool.
+    auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
+    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+
+    Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultTy.getShape(), resultETy);
+
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
+        ValueRange{genericInitTensor},
+        ArrayRef<AffineMap>({affineMap, affineMap}),
+        getNParallelLoopsAttrs(resultTy.getRank()),
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+          auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+          auto iH = rewriter.create<ConstantIndexOp>(
+              loc, poolingOpTy.getDimSize(1) - 1);
+          auto iW = rewriter.create<ConstantIndexOp>(
+              loc, poolingOpTy.getDimSize(2) - 1);
+
+          // Compute the indices from either end.
+          auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
+          auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
+          auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
+          auto x1 = rewriter.create<SubIOp>(loc, iW, x0);
+
+          // Determines what the portion of valid input is covered by the
+          // kernel.
+          auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
+            if (pad == 0)
+              return v;
+
+            auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
+            Value dx = rewriter.create<SubIOp>(loc, x, padVal);
+
+            Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
+                                                      dx, zero);
+            Value offset = rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
+            return rewriter.create<mlir::AddIOp>(loc, v, offset)->getResult(0);
+          };
+
+          // Compute the vertical component of coverage.
+          auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
+          auto kH1 = padFn(kH0, y0, pad[2]);
+          auto kH2 = padFn(kH1, y1, pad[3]);
+          auto kHCmp =
+              rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
+          auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
+
+          // Compute the horizontal component of coverage.
+          auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
+          auto kW1 = padFn(kW0, x0, pad[4]);
+          auto kW2 = padFn(kW1, x1, pad[5]);
+          auto kWCmp =
+              rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
+          auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
+
+          // Compute the total number of elements and normalize.
+          Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
+          auto countI = rewriter.create<mlir::IndexCastOp>(
+              loc, rewriter.getI32Type(), count);
+
+          // Divide by the number of summed values. For floats this is just
+          // a div however for quantized values input normalization had
+          // to be applied.
+          Value poolVal = args[0];
+          if (accETy.isa<FloatType>()) {
             auto countF =
                 rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);
+            poolVal =
+                rewriter.create<DivFOp>(loc, poolVal, countF)->getResult(0);
+          } else {
 
-            auto div =
-                rewriter.create<DivFOp>(loc, args[0], countF)->getResult(0);
+            // If we have quantization information we need to apply an offset
+            // for the input zp value.
+            if (op.quantization_info()) {
+              auto quantizationInfo = op.quantization_info().getValue();
+              auto inputZp = rewriter.create<mlir::ConstantOp>(
+                  loc, quantizationInfo.input_zp());
+              Value offset =
+                  rewriter.create<mlir::MulIOp>(loc, accETy, countI, inputZp);
+              poolVal = rewriter.create<SubIOp>(loc, accETy, poolVal, offset);
+            }
 
-            rewriter.create<linalg::YieldOp>(loc, div);
-          });
+            // Compute the multiplier and shift values for the quantization
+            // normalization. Preferably we would want to compute more bits
+            // however 32-bits should be enough for compute. Honestly we
+            // should probably straight divide.
+            int64_t numerator = ((1 << 30) + 1);
+            int64_t shift = 30;
+
+            Value numeratorVal = rewriter.create<ConstantOp>(
+                loc, rewriter.getI32IntegerAttr(numerator));
+            Value multiplierVal =
+                rewriter
+                    .create<UnsignedDivIOp>(loc, rewriter.getI32Type(),
+                                            numeratorVal, countI)
+                    .getResult();
+            Value shiftVal = rewriter.create<ConstantOp>(
+                loc, rewriter.getI8IntegerAttr(shift));
+
+            auto scaled =
+                rewriter
+                    .create<tosa::ApplyScaleOp>(
+                        loc, rewriter.getI32Type(), poolVal, multiplierVal,
+                        shiftVal, rewriter.getBoolAttr(false))
+                    .getResult();
+
+            // If we have quantization information we need to apply output
+            // zeropoint.
+            if (op.quantization_info()) {
+              auto quantizationInfo = op.quantization_info().getValue();
+              auto outputZp = rewriter.create<mlir::ConstantOp>(
+                  loc, quantizationInfo.output_zp());
+              scaled =
+                  rewriter.create<AddIOp>(loc, scaled, outputZp).getResult();
+            }
 
-      rewriter.replaceOp(op, genericOp.getResult(0));
-      return success();
-    }
+            // Apply Clip.
+            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
+
+            auto min = rewriter.create<ConstantOp>(
+                loc, rewriter.getIntegerAttr(
+                         accETy,
+                         APInt::getSignedMinValue(outBitwidth).getSExtValue()));
+            auto max = rewriter.create<ConstantOp>(
+                loc, rewriter.getIntegerAttr(
+                         accETy,
+                         APInt::getSignedMaxValue(outBitwidth).getSExtValue()));
+            auto clamp = clampHelper<mlir::CmpIOp>(
+                loc, scaled, min, max, CmpIPredicate::slt, rewriter);
+
+            // Convert type.
+            poolVal = rewriter.create<TruncateIOp>(loc, resultETy, clamp);
+          }
 
-    return failure();
+          // Cast to output type.
+
+          rewriter.create<linalg::YieldOp>(loc, poolVal);
+        });
+
+    rewriter.replaceOp(op, genericOp.getResult(0));
+    return success();
   }
 };
 
@@ -2719,8 +2837,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       TileConverter,
       TransposeConverter,
       MatMulConverter,
-      Pool2dConverter<tosa::AvgPool2dOp>,
-      Pool2dConverter<tosa::MaxPool2dOp>,
+      MaxPool2dConverter,
+      AvgPool2dConverter,
       FullyConnectedConverter>(patterns->getContext());
   // clang-format on
 }

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 329c4b27e421..c9e2f65907a4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1203,11 +1203,12 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
   // CHECK: [[CONST:%.+]] = constant 0
   // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK: [[CONST:%.+]] = constant 0
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
-  // CHECK: [[FILL:%.+]] = linalg.fill([[CONST]], [[INIT]])
+  // CHECK: [[POOLINIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
+  // CHECK: [[FILL:%.+]] = linalg.fill([[CONST]], [[POOLINIT]])
   // CHECK: [[KERNEL:%.+]] = linalg.init_tensor [4, 4]
   // CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>)
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs([[POOL]] : tensor<1x5x33x62xf32>)
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[POOL]] : tensor<1x5x33x62xf32>) outs([[INIT]] : tensor<1x5x33x62xf32>)
   // CHECK:   [[ZERO:%.0]] = constant 0
   // CHECK:   [[ONE:%.+]] = constant 1
   // CHECK:   [[HEIGHT:%.+]] = constant 4
@@ -1257,17 +1258,46 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
 
 // -----
 
-// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
-// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-LABEL: @avg_pool_i8
+func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () {
+
+  // CHECK: linalg.pooling_nhwc_sum
+  // CHECK: linalg.generic
+
+  // CHECK: %[[INZP:.+]] = constant -128
+  // CHECK: %[[INZP_OFF:.+]] = muli %{{.+}}, %[[INZP]]
+  // CHECK: %[[OFFSETED:.+]] = subi %arg1, %[[INZP_OFF]]
+  // CHECK: %[[NUMERATOR:.+]] = constant 1073741825
+  // CHECK: %[[MULTIPLIER:.+]] = divi_unsigned %[[NUMERATOR]], %{{.+}}
+  // CHECK: %[[SHIFT:.+]] = constant 30
+  // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
+  // CHECK: %[[OUTZP:.+]] = constant -128
+  // CHECK: %[[OUT:.+]] = addi %[[SCALE]], %[[OUTZP]]
+  // CHECK: %[[MIN:.+]] = constant -128
+  // CHECK: %[[MAX:.+]] = constant 127
+  // CHECK: %[[CMP_MIN:.+]] = cmpi slt, %[[OUT]], %[[MIN]]
+  // CHECK: %[[CLMP_MIN:.+]] = select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
+  // CHECK: %[[CMP_MAX:.+]] = cmpi slt, %[[MAX]], %[[OUT]]
+  // CHECK: %[[CLMP_MAX:.+]] = select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
+  // CHECK: %[[TRUNC:.+]] = trunci %[[CLMP_MAX]]
+  // CHECK: linalg.yield %[[TRUNC]]
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8>
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 
-// CHECK-LABEL @conv2d_f32
+// CHECK-LABEL: @conv2d_f32
 func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[W_IN:.+]] = linalg.init_tensor [3, 3, 27, 28]
-  // CHECK: %[[W:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[W_IN]] : tensor<3x3x27x28xf32>)
+  // CHECK: %[[W:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[W_IN]] : tensor<3x3x27x28xf32>)
   // CHECK:   linalg.yield %arg3 : f32
   // CHECK: %[[B_IN:.+]] = linalg.init_tensor [1, 45, 40, 28]
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
   // CHECK:   linalg.yield %arg3 : f32
   // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %1 : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[B]] : tensor<1x45x40x28xf32>)
   %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>)  -> (tensor<1x45x40x28xf32>)


        

