[Mlir-commits] [mlir] [mlir][tosa] Convert group tosa::Conv2DOp to linalg conv (PR #108192)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 11 04:37:04 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (stefankoncarevic)
<details>
<summary>Changes</summary>
This patch adds two new named ops, linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp, and uses them to convert grouped tosa.conv2d ops to Linalg.
- Added linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp.
- Updated the TOSA-to-Linalg conversion to use these new ops for grouped tosa.conv2d operations.
---
Patch is 26.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108192.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (+237)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+12)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-1)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+103-42)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+13)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+3-2)
- (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+54)
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+16)
- (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 8cb698096ef5b7..011c4858d6521b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -3410,6 +3410,243 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: K
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: conv_2d_nhwgc_gfhwc
+ cpp_class_name: Conv2DNhwgcGfhwcOp
+ doc: |-
+ Performs 2-D grouped convolution.
+
+ Layout:
+ * Input: NHWGC.
+ * Kernel: GFHWC.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
+ - !LinalgOperandDefConfig
+ name: K
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s9, s11, s3, s7, s10)>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s0, s1, s5, s9, s11)>
+ - !LinalgOperandDefConfig
+ name: strides
+ kind: index_attr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+ -> (s2, s6)>
+ default_indices:
+ - 1
+ - 1
+ - !LinalgOperandDefConfig
+ name: dilations
+ kind: index_attr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+ -> (s4, s8)>
+ default_indices:
+ - 1
+ - 1
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+ s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+ s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+ s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ - reduction
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: O
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: K
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: conv_2d_nhwgc_gfhwc_q
+ cpp_class_name: Conv2DNhwgcGfhwcQOp
+ doc: |-
+ Performs 2-D grouped convolution with zero point offsets.
+
+ Layout:
+ * Input: NHWGC.
+ * Kernel: GFHWC.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. This includes the zero
+ point offsets common to quantized operations.
+ implements:
+ - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
+ - !LinalgOperandDefConfig
+ name: K
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s9, s11, s3, s7, s10)>
+ - !LinalgOperandDefConfig
+ name: IZp
+ kind: scalar
+ type_var: I32
+ - !LinalgOperandDefConfig
+ name: KZp
+ kind: scalar
+ type_var: I32
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s0, s1, s5, s9, s11)>
+ - !LinalgOperandDefConfig
+ name: strides
+ kind: index_attr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+ -> (s2, s6)>
+ default_indices:
+ - 1
+ - 1
+ - !LinalgOperandDefConfig
+ name: dilations
+ kind: index_attr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+ -> (s4, s8)>
+ default_indices:
+ - 1
+ - 1
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+ s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+ s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+ s8, s9, s10, s11] -> ()>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+ s8, s9, s10, s11] -> ()>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+ s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ - reduction
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: O
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: sub
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: IZp
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: sub
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: K
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: KZp
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_ngchw_gfchw_q
cpp_class_name: Conv2DNgchwGfchwQOp
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d20..d4697f0afbf466 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -133,6 +133,18 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder<
pad, stride, dilation);
}]>;
+// Handles tosa.conv2d which has an optional group attribute (grouped convolution).
+def Tosa_ConvOpGroupQuantBuilder : OpBuilder<
+ (ins "::mlir::Type":$outputType, "::mlir::Value":$input,
+ "::mlir::Value":$weight, "::mlir::Value":$bias,
+ "::mlir::DenseI64ArrayAttr":$pad, "::mlir::DenseI64ArrayAttr":$stride,
+ "::mlir::DenseI64ArrayAttr":$dilation, "::mlir::IntegerAttr":$group),
+ [{
+ buildConvOpWithQuantInfo($_builder, $_state, outputType,
+ input, weight, bias,
+ pad, stride, dilation, group);
+ }]>;
+
// Handles tosa.transpose_conv2d which has an outpad and output shape attribute.
def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
(ins "::mlir::Type":$outputType, "::mlir::Value":$input,
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ab6daa39708d13..0b67019fd0c7bb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -108,6 +108,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
+ OptionalAttr<I64Attr>:$group,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -116,7 +117,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
Tosa_Tensor4D:$output
);
- let builders = [Tosa_ConvOpQuantInfoBuilder];
+ let builders = [Tosa_ConvOpQuantInfoBuilder, Tosa_ConvOpGroupQuantBuilder];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 77c3d2e8757910..898bed4a895864 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -236,6 +236,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
LogicalResult
matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
+ bool isConv2DOp = isa<tosa::Conv2DOp>(op);
Location loc = op->getLoc();
Value input = op->getOperand(0);
Value weight = op->getOperand(1);
@@ -253,6 +254,24 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
bool isQuantized = op.getQuantizationInfo().has_value();
+ int64_t group = 1;
+
+ if (auto convop = dyn_cast<tosa::Conv2DOp>(&op)) {
+ if (convop->getGroup().has_value())
+ group = convop->getGroup().value();
+ }
+
+ if (group > 1 && isConv2DOp &&
+ !std::is_same<LinalgConvOp, linalg::Conv2DNhwgcGfhwcOp>::value &&
+ !std::is_same<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>::value)
+ return rewriter.notifyMatchFailure(
+ op, "tosa.conv ops should map to grouped convolution ops");
+
+ if (group == 1 && isConv2DOp &&
+ !std::is_same<LinalgConvOp, linalg::Conv2DNhwcFhwcOp>::value &&
+ !std::is_same<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>::value)
+ return rewriter.notifyMatchFailure(
+ op, "tosa.conv ops should map to non-grouped convolution ops");
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return rewriter.notifyMatchFailure(
@@ -274,8 +293,6 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
inputSizeDims, kernelSizeDims, rewriter);
- auto weightShape = weightTy.getShape();
-
// Apply padding as necessary.
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (isQuantized) {
@@ -302,15 +319,64 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
pad.resize(pad.size() + 2, 0);
input = applyPad(loc, input, pad, zeroAttr, rewriter);
- if (4 == inputTy.getRank()) {
- // For 2D convolutions, we need to check if the target convolution op
- // wants a HWCF kernel layout.
- bool wantHwcf =
- isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
- : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
- if (wantHwcf) {
- // Transpose the kernel to match dimension ordering of the linalg
- // convolution operation.
+ auto weightShape = weightTy.getShape();
+ SmallVector<int64_t> weightPerm;
+
+ auto resultShape = resultTy.getShape();
+ auto newResultTy = resultTy;
+
+ if (isConv2DOp && group > 1) {
+ // Map 4D-tensors to 5D tensors
+ auto inputShape = cast<ShapedType>(input.getType()).getShape();
+ SmallVector<int64_t, 5> newInputShape = {inputShape[0], inputShape[1],
+ inputShape[2], group,
+ inputShape[3] / group};
+
+ SmallVector<int64_t, 5> newWeightShape = {group, weightShape[0] / group,
+ weightShape[1], weightShape[2],
+ weightShape[3]};
+ input = rewriter.create<tosa::ReshapeOp>(
+ loc, RankedTensorType::get(newInputShape, inputETy), input,
+ rewriter.getDenseI64ArrayAttr(newInputShape));
+ weight = rewriter.create<tosa::ReshapeOp>(
+ loc, RankedTensorType::get(newWeightShape, weightTy.getElementType()),
+ weight, rewriter.getDenseI64ArrayAttr(newWeightShape));
+ } else {
+
+ if (4 == inputTy.getRank()) {
+ // For 2D convolutions, we need to check if the target convolution op
+ // wants a HWCF kernel layout.
+ bool wantHwcf =
+ isQuantized
+ ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+ : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+ if (wantHwcf) {
+ // Transpose the kernel to match dimension ordering of the linalg
+ // convolution operation.
+ // TODO(suderman): See if this can be efficiently folded - check
+ // whether the input is used anywhere else, if not fold the constant.
+ SmallVector<int64_t> weightPerm;
+ for (int i = 1; i < resultTy.getRank(); i++)
+ weightPerm.push_back(i);
+ weightPerm.push_back(0);
+
+ SmallVector<int64_t> newWeightShape;
+ for (auto dim : weightPerm)
+ newWeightShape.push_back(weightShape[dim]);
+ auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+ Value weightPermValue =
+ rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+ Type newWeightTy =
+ RankedTensorType::get(newWeightShape, weightTy.getElementType());
+ weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+ weightPermValue);
+ }
+ }
+
+ // For Conv3D transpose the kernel to match dimension ordering of the
+ // linalg convolution operation. Conv2D has a 1-1 mapping in linalg so
+ // better to map directly and then transpose later if desired.
+ if (5 == inputTy.getRank()) {
// TODO(suderman): See if this can be efficiently folded - check whether
// the input is used anywhere else, if not fold the constant.
SmallVector<int64_t> weightPerm;
@@ -331,27 +397,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
}
}
- // For Conv3D transpose the kernel to match dimension ordering of the linalg
- // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
- // map directly and then transpose later if desired.
- if (5 == inputTy.getRank()) {
- // TODO(suderman): See if this can be efficiently folded - check whether
- // the input is used anywhere else, if not fold the constant.
- SmallVector<int64_t> weightPerm;
- for (int i = 1; i < resultTy.getRank(); i++)
- weightPerm.push_back(i);
- weightPerm.push_back(0);
-
- SmallVector<int64_t> newWeightShape;
- for (auto dim : weightPerm)
- newWeightShape.push_back(weightShape[dim]);
- auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
- Value weightPermValue =
- rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
- Type newWeightTy =
- RankedTensorType::get(newWeightShape, weightTy.getElementType());
- weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
- weightPermValue);
+ if (isConv2DOp && group > 1) {
+ SmallVector<int64_t, 5> newResultShape{resultShape[0], resultShape[1],
+ resultShape[2], group,
+ resultShape[3] / group};
+ newResultTy = RankedTensorType::get(newResultShape, resultETy);
}
// Extract the attributes for convolution.
@@ -368,6 +418,13 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
Value broadcastBias =
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
+ if (isConv2DOp && group > 1) {
+ broadcastBias = rewriter.create<tosa::ReshapeOp>(
+ loc, RankedTensorType::get(newResultTy.getShape(), resultETy),
+ broadcastBias, rewriter.getDenseI64ArrayAttr(newResultTy.getShape()));
+ }
+
+ Value conv;
if (isQuantized) {
auto quantizationInfo = *op.getQuantizationInfo();
auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
@@ -376,22 +433,25 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
- Value conv =
+ conv =
rewriter
.create<LinalgConvQOp>(
- loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
+ loc, newResultTy, ValueRange{input, weight, iZpVal, kZpVal},
ValueRange{broadcastBias}, strideAttr, dilationAttr)
->getResult(0);
-
- rewriter.replaceOp(op, conv);
- return success();
+ } else {
+ conv = rewriter
+ .create<LinalgConvOp>(
+ loc, newResultTy, ValueRange{input, weight},
+ ValueRange{broadcastBias}, strideAttr, dilationAttr)
+ ->getResult(0);
}
- Value conv = rewriter
- .create<LinalgConvOp>(
- loc, resultTy, ValueRange{input, weight},
- ValueRange{broadcastBias}, strideAttr, dilationAttr)
- ->getResult(0);
+ if (isConv2DOp && group > 1) {
+ conv = rewriter.create<tosa::ReshapeOp>(
+ loc, RankedTensorType::get(resultShape, resultETy), conv,
+ rewriter.getDenseI64ArrayAttr(resultShape));
+ }
rewriter.replaceOp(op, conv);
return success();
@@ -1074,6 +1134,7 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
}
patterns->add<
// clang-format off
+ ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwgcGfhwcOp, linalg::Conv2DNhwgcGfhwcQOp>,
ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
DepthwiseConvConverter,
MatMulConverter,
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d93db1b237f316..cedfa8a5afd110 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -383,6 +383,19 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
}
}
+// Handles grouped convolution
+static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+ Type outputType, Value input, Value weight,...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/108192
More information about the Mlir-commits mailing list