[Mlir-commits] [mlir] 5541a05 - [mlir][tosa] Quantized tosa.avg_pool2d lowering to linalg
Rob Suderman
llvmlistbot at llvm.org
Tue Aug 24 19:02:55 PDT 2021
Author: Rob Suderman
Date: 2021-08-24T18:54:23-07:00
New Revision: 5541a05d6a5a74424dd6d98cfb6d9014a5fb17ca
URL: https://github.com/llvm/llvm-project/commit/5541a05d6a5a74424dd6d98cfb6d9014a5fb17ca
DIFF: https://github.com/llvm/llvm-project/commit/5541a05d6a5a74424dd6d98cfb6d9014a5fb17ca.diff
LOG: [mlir][tosa] Quantized tosa.avg_pool2d lowering to linalg
Includes the quantized version of average pool lowering to linalg dialect.
This includes a lit test for the transform. It is not 100% correct, as the
multiplier / shift should be computed in i64; however, this only introduces a
negligible rounding difference.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D108676
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 880da55cbd24..f7e8b0e078eb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2504,39 +2504,34 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
}
};
-template <typename SrcOp>
-class Pool2dConverter : public OpRewritePattern<SrcOp> {
+class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
- using OpRewritePattern<SrcOp>::OpRewritePattern;
+ using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(SrcOp op,
+ LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
Value input = op.input();
ShapedType inputTy = input.getType().cast<ShapedType>();
- Type inElementTy = inputTy.getElementType();
ShapedType resultTy = op.getType().template cast<ShapedType>();
- Type outElementTy = inputTy.getElementType();
+ Type resultETy = inputTy.getElementType();
if (!inputTy.hasStaticShape())
return failure();
// Determine what the initial value needs to be for the max pool op.
Attribute initialAttr;
- if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isF32())
+ if (resultETy.isF32())
initialAttr = rewriter.getFloatAttr(
- outElementTy,
- APFloat::getLargest(
- outElementTy.cast<FloatType>().getFloatSemantics(), true));
+ resultETy,
+ APFloat::getLargest(resultETy.cast<FloatType>().getFloatSemantics(),
+ true));
- if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isa<IntegerType>())
+ if (resultETy.isa<IntegerType>())
initialAttr = rewriter.getIntegerAttr(
- outElementTy,
- APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));
-
- if (isa<tosa::AvgPool2dOp>(op) && outElementTy.isa<FloatType>())
- initialAttr = rewriter.getZeroAttr(outElementTy);
+ resultETy,
+ APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
if (!initialAttr)
return rewriter.notifyMatchFailure(
@@ -2566,93 +2561,216 @@ class Pool2dConverter : public OpRewritePattern<SrcOp> {
rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();
Value fakeWindowDims =
- rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);
+ rewriter.create<linalg::InitTensorOp>(loc, kernel, resultETy);
- if (isa<tosa::MaxPool2dOp>(op)) {
- rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
- op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
- filledInitTensor, strideAttr, dilationAttr);
- return success();
- }
+ rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
+ op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
+ filledInitTensor, strideAttr, dilationAttr);
+ return success();
+ }
+};
- if (isa<tosa::AvgPool2dOp>(op) && inElementTy.isF32()) {
- Value poolingOp = rewriter
- .create<linalg::PoolingNhwcSumOp>(
- loc, ArrayRef<Type>{resultTy},
- ValueRange{paddedInput, fakeWindowDims},
- filledInitTensor, strideAttr, dilationAttr)
- .getResult(0);
- auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
- auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
- auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, ArrayRef<Type>({resultTy}), ValueRange{}, ValueRange{poolingOp},
- ArrayRef<AffineMap>({affineMap}),
- getNParallelLoopsAttrs(resultTy.getRank()),
- [&](OpBuilder &b, Location loc, ValueRange args) {
- auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
- auto one = rewriter.create<ConstantIndexOp>(loc, 1);
- auto iH = rewriter.create<ConstantIndexOp>(
- loc, poolingOpTy.getDimSize(1) - 1);
- auto iW = rewriter.create<ConstantIndexOp>(
- loc, poolingOpTy.getDimSize(2) - 1);
-
- // Compute the indices from either end.
- auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
- auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
- auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
- auto x1 = rewriter.create<SubIOp>(loc, iW, x0);
-
- // Determines what the portion of valid input is covered by the
- // kernel.
- auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
- if (pad == 0)
- return v;
-
- auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
- Value dx = rewriter.create<SubIOp>(loc, x, padVal);
-
- Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
- dx, zero);
- Value offset =
- rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
- return rewriter.create<mlir::AddIOp>(loc, v, offset)
- ->getResult(0);
- };
-
- // Compute the vertical component of coverage.
- auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
- auto kH1 = padFn(kH0, y0, pad[2]);
- auto kH2 = padFn(kH1, y1, pad[3]);
- auto kHCmp =
- rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
- auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
-
- // compute teh horizontal component of coverage.
- auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
- auto kW1 = padFn(kW0, x0, pad[4]);
- auto kW2 = padFn(kW1, x1, pad[5]);
- auto kWCmp =
- rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
- auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
-
- // Compute the total number of elements and normalize.
- Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
- auto countI = rewriter.create<mlir::IndexCastOp>(
- loc, rewriter.getI32Type(), count);
+class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
+public:
+ using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
+ PatternRewriter &rewriter) const final {
+ Location loc = op.getLoc();
+ Value input = op.input();
+ ShapedType inputTy = input.getType().cast<ShapedType>();
+ Type inElementTy = inputTy.getElementType();
+
+ ShapedType resultTy = op.getType().template cast<ShapedType>();
+ Type resultETy = inputTy.getElementType();
+
+ Type accETy =
+ inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
+ ShapedType accTy = resultTy.clone(accETy);
+
+ if (!inputTy.hasStaticShape())
+ return failure();
+
+ // Apply padding as necessary.
+ llvm::SmallVector<int64_t> pad;
+ pad.resize(2, 0);
+ getValuesFromIntArrayAttribute(op.pad(), pad);
+ pad.resize(pad.size() + 2, 0);
+ Attribute initialAttr = rewriter.getZeroAttr(accETy);
+ Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
+
+ Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
+
+ SmallVector<int64_t> kernel, stride;
+ getValuesFromIntArrayAttribute(op.kernel(), kernel);
+ getValuesFromIntArrayAttribute(op.stride(), stride);
+
+ Attribute strideAttr = rewriter.getI64VectorAttr(stride);
+ Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
+
+ // Create the linalg op that performs pooling.
+ Value poolInitTensor =
+ rewriter.create<linalg::InitTensorOp>(loc, accTy.getShape(), accETy);
+
+ Value filledInitTensor =
+ rewriter.create<linalg::FillOp>(loc, initialValue, poolInitTensor)
+ .result();
+
+ Value fakeWindowDims =
+ rewriter.create<linalg::InitTensorOp>(loc, kernel, accETy);
+
+ // Sum across the pooled region.
+ Value poolingOp = rewriter
+ .create<linalg::PoolingNhwcSumOp>(
+ loc, ArrayRef<Type>{accTy},
+ ValueRange{paddedInput, fakeWindowDims},
+ filledInitTensor, strideAttr, dilationAttr)
+ .getResult(0);
+
+ // Normalize the summed value by the number of elements grouped in each
+ // pool.
+ auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
+ auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+
+ Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
+ loc, resultTy.getShape(), resultETy);
+
+ auto genericOp = rewriter.create<linalg::GenericOp>(
+ loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
+ ValueRange{genericInitTensor},
+ ArrayRef<AffineMap>({affineMap, affineMap}),
+ getNParallelLoopsAttrs(resultTy.getRank()),
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+ auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+ auto iH = rewriter.create<ConstantIndexOp>(
+ loc, poolingOpTy.getDimSize(1) - 1);
+ auto iW = rewriter.create<ConstantIndexOp>(
+ loc, poolingOpTy.getDimSize(2) - 1);
+
+ // Compute the indices from either end.
+ auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
+ auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
+ auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
+ auto x1 = rewriter.create<SubIOp>(loc, iW, x0);
+
+ // Determines what the portion of valid input is covered by the
+ // kernel.
+ auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
+ if (pad == 0)
+ return v;
+
+ auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
+ Value dx = rewriter.create<SubIOp>(loc, x, padVal);
+
+ Value cmp = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
+ dx, zero);
+ Value offset = rewriter.create<mlir::SelectOp>(loc, cmp, dx, zero);
+ return rewriter.create<mlir::AddIOp>(loc, v, offset)->getResult(0);
+ };
+
+ // Compute the vertical component of coverage.
+ auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
+ auto kH1 = padFn(kH0, y0, pad[2]);
+ auto kH2 = padFn(kH1, y1, pad[3]);
+ auto kHCmp =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
+ auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
+
+ // compute the horizontal component of coverage.
+ auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
+ auto kW1 = padFn(kW0, x0, pad[4]);
+ auto kW2 = padFn(kW1, x1, pad[5]);
+ auto kWCmp =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
+ auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
+
+ // Compute the total number of elements and normalize.
+ Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
+ auto countI = rewriter.create<mlir::IndexCastOp>(
+ loc, rewriter.getI32Type(), count);
+
+ // Divide by the number of summed values. For floats this is just
+ // a div however for quantized values input normalization had
+ // to be applied.
+ Value poolVal = args[0];
+ if (accETy.isa<FloatType>()) {
auto countF =
rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);
+ poolVal =
+ rewriter.create<DivFOp>(loc, poolVal, countF)->getResult(0);
+ } else {
- auto div =
- rewriter.create<DivFOp>(loc, args[0], countF)->getResult(0);
+ // If we have quantization information we need to apply an offset
+ // for the input zp value.
+ if (op.quantization_info()) {
+ auto quantizationInfo = op.quantization_info().getValue();
+ auto inputZp = rewriter.create<mlir::ConstantOp>(
+ loc, quantizationInfo.input_zp());
+ Value offset =
+ rewriter.create<mlir::MulIOp>(loc, accETy, countI, inputZp);
+ poolVal = rewriter.create<SubIOp>(loc, accETy, poolVal, offset);
+ }
- rewriter.create<linalg::YieldOp>(loc, div);
- });
+ // Compute the multiplier and shift values for the quantization
+ // normalization. Preferably we would want to compute more bits
+ // however 32-bits should be enough for compute. Honestly we
+ // should probably straight divide.
+ int64_t numerator = ((1 << 30) + 1);
+ int64_t shift = 30;
+
+ Value numeratorVal = rewriter.create<ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(numerator));
+ Value multiplierVal =
+ rewriter
+ .create<UnsignedDivIOp>(loc, rewriter.getI32Type(),
+ numeratorVal, countI)
+ .getResult();
+ Value shiftVal = rewriter.create<ConstantOp>(
+ loc, rewriter.getI8IntegerAttr(shift));
+
+ auto scaled =
+ rewriter
+ .create<tosa::ApplyScaleOp>(
+ loc, rewriter.getI32Type(), poolVal, multiplierVal,
+ shiftVal, rewriter.getBoolAttr(false))
+ .getResult();
+
+ // If we have quantization information we need to apply output
+ // zeropoint.
+ if (op.quantization_info()) {
+ auto quantizationInfo = op.quantization_info().getValue();
+ auto outputZp = rewriter.create<mlir::ConstantOp>(
+ loc, quantizationInfo.output_zp());
+ scaled =
+ rewriter.create<AddIOp>(loc, scaled, outputZp).getResult();
+ }
- rewriter.replaceOp(op, genericOp.getResult(0));
- return success();
- }
+ // Apply Clip.
+ int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
+
+ auto min = rewriter.create<ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ accETy,
+ APInt::getSignedMinValue(outBitwidth).getSExtValue()));
+ auto max = rewriter.create<ConstantOp>(
+ loc, rewriter.getIntegerAttr(
+ accETy,
+ APInt::getSignedMaxValue(outBitwidth).getSExtValue()));
+ auto clamp = clampHelper<mlir::CmpIOp>(
+ loc, scaled, min, max, CmpIPredicate::slt, rewriter);
+
+ // Convert type.
+ poolVal = rewriter.create<TruncateIOp>(loc, resultETy, clamp);
+ }
- return failure();
+ // Cast to output type.
+
+ rewriter.create<linalg::YieldOp>(loc, poolVal);
+ });
+
+ rewriter.replaceOp(op, genericOp.getResult(0));
+ return success();
}
};
@@ -2719,8 +2837,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
TileConverter,
TransposeConverter,
MatMulConverter,
- Pool2dConverter<tosa::AvgPool2dOp>,
- Pool2dConverter<tosa::MaxPool2dOp>,
+ MaxPool2dConverter,
+ AvgPool2dConverter,
FullyConnectedConverter>(patterns->getContext());
// clang-format on
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 329c4b27e421..c9e2f65907a4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1203,11 +1203,12 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
// CHECK: [[CONST:%.+]] = constant 0
// CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
// CHECK: [[CONST:%.+]] = constant 0
- // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
- // CHECK: [[FILL:%.+]] = linalg.fill([[CONST]], [[INIT]])
+ // CHECK: [[POOLINIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
+ // CHECK: [[FILL:%.+]] = linalg.fill([[CONST]], [[POOLINIT]])
// CHECK: [[KERNEL:%.+]] = linalg.init_tensor [4, 4]
// CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>)
- // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs([[POOL]] : tensor<1x5x33x62xf32>)
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[POOL]] : tensor<1x5x33x62xf32>) outs([[INIT]] : tensor<1x5x33x62xf32>)
// CHECK: [[ZERO:%.0]] = constant 0
// CHECK: [[ONE:%.+]] = constant 1
// CHECK: [[HEIGHT:%.+]] = constant 4
@@ -1257,17 +1258,46 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
// -----
-// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
-// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-LABEL: @avg_pool_i8
+func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () {
+
+ // CHECK: linalg.pooling_nhwc_sum
+ // CHECK: linalg.generic
+
+ // CHECK: %[[INZP:.+]] = constant -128
+ // CHECK: %[[INZP_OFF:.+]] = muli %{{.+}}, %[[INZP]]
+ // CHECK: %[[OFFSETED:.+]] = subi %arg1, %[[INZP_OFF]]
+ // CHECK: %[[NUMERATOR:.+]] = constant 1073741825
+ // CHECK: %[[MULTIPLIER:.+]] = divi_unsigned %[[NUMERATOR]], %{{.+}}
+ // CHECK: %[[SHIFT:.+]] = constant 30
+ // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
+ // CHECK: %[[OUTZP:.+]] = constant -128
+ // CHECK: %[[OUT:.+]] = addi %[[SCALE]], %[[OUTZP]]
+ // CHECK: %[[MIN:.+]] = constant -128
+ // CHECK: %[[MAX:.+]] = constant 127
+ // CHECK: %[[CMP_MIN:.+]] = cmpi slt, %[[OUT]], %[[MIN]]
+ // CHECK: %[[CLMP_MIN:.+]] = select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
+ // CHECK: %[[CMP_MAX:.+]] = cmpi slt, %[[MAX]], %[[OUT]]
+ // CHECK: %[[CLMP_MAX:.+]] = select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
+ // CHECK: %[[TRUNC:.+]] = trunci %[[CLMP_MAX]]
+ // CHECK: linalg.yield %[[TRUNC]]
+ %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8>
+ return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK-LABEL @conv2d_f32
+// CHECK-LABEL: @conv2d_f32
func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
// CHECK: %[[W_IN:.+]] = linalg.init_tensor [3, 3, 27, 28]
- // CHECK: %[[W:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[W_IN]] : tensor<3x3x27x28xf32>)
+ // CHECK: %[[W:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[W_IN]] : tensor<3x3x27x28xf32>)
// CHECK: linalg.yield %arg3 : f32
// CHECK: %[[B_IN:.+]] = linalg.init_tensor [1, 45, 40, 28]
- // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
+ // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
// CHECK: linalg.yield %arg3 : f32
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %1 : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[B]] : tensor<1x45x40x28xf32>)
%0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)
More information about the Mlir-commits
mailing list