[Mlir-commits] [mlir] 7ce53e3 - [mlir][tosa] Add tosa.conv3d lowering to Linalg

Rob Suderman llvmlistbot at llvm.org
Fri Jan 6 10:47:54 PST 2023

Author: Rob Suderman
Date: 2023-01-06T10:47:45-08:00
New Revision: 7ce53e31023dcf9d8fb95d172e20a35e60ebd821

URL: https://github.com/llvm/llvm-project/commit/7ce53e31023dcf9d8fb95d172e20a35e60ebd821
DIFF: https://github.com/llvm/llvm-project/commit/7ce53e31023dcf9d8fb95d172e20a35e60ebd821.diff

LOG: [mlir][tosa] Add tosa.conv3d lowering to Linalg

Conv3D has an existing linalg operation for floating point. Adding a quantized
variant and corresponding lowering from TOSA. Numerical correctness was validated
using the TOSA conformance tests.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D140919




diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 9d771fc1b738a..249f60ba02979 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2001,6 +2001,145 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_3d_ndhwc_dhwcf_q
+  cpp_class_name: Conv3DNdhwcDhwcfQOp
+  doc: |-
+    Performs 3-D convolution with zero point offsets.
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
+      s13, s14] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12,
+      s13)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
+      s13, s14] -> (s3, s7, s11, s13, s14)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
+      s13, s14] -> (s0, s1, s5, s9, s14)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12, s13, s14] -> (s2, s6, s10)>
+    default_indices:
+    - 1
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12, s13, s14] -> (s4, s8, s12)>
+    default_indices:
+    - 1
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6
+      * s8, d3 * s10 + d7 * s12, d8)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> (d5, d6, d7, d8, d4)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d1, d2, d3, d4)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_1d_nwc_wc
   cpp_class_name: DepthwiseConv1DNwcWcOp
@@ -4441,3 +4580,4 @@ structured_op: !LinalgStructuredOpConfig
                           scalar_const: '2.3283063999999999E-10 : f64'
             - !ScalarExpression
               scalar_arg: min

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 52f9d3416e2a6..f9732efaa27eb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -99,47 +99,40 @@ static mlir::Value getConvOutputDim(Location loc, Value inputDim,
 // Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
-static SmallVector<Value>
-inferDynamicDimsForConv(Location loc, Value input, Value weight,
-                        ShapedType resultTy, DenseI64ArrayAttr padAttr,
-                        DenseI64ArrayAttr strideAttr,
-                        DenseI64ArrayAttr dilationAttr, int64_t weightHDim,
-                        int64_t weightWDim, OpBuilder &rewriter) {
+static SmallVector<Value> inferDynamicDimsForConv(
+    Location loc, Value input, Value weight, ShapedType resultTy,
+    ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
+    ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
+    ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
   ShapedType inputTy = input.getType().cast<ShapedType>();
   Type inputETy = inputTy.getElementType();
   int64_t inputRank = inputTy.getRank();
-  int64_t heightDim = 1;
-  int64_t weightDim = 2;
   SmallVector<Value> dynDims;
-  for (int i = 0; i < inputRank; i++) {
-    if (inputTy.isDynamicDim(i) && i != heightDim && i != weightDim)
-      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
-  }
-  // Dynamic input height
-  if (inputTy.isDynamicDim(heightDim)) {
-    Value initHDim =
-        rewriter.create<tensor::DimOp>(loc, input, heightDim).getResult();
-    Value kernelHDim =
-        rewriter.create<tensor::DimOp>(loc, weight, weightHDim).getResult();
-    // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
-    dynDims[heightDim] =
-        getConvOutputDim(loc, initHDim, padAttr[0], padAttr[1], kernelHDim,
-                         strideAttr[0], dilationAttr[0], inputETy, rewriter);
+  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
+    int64_t inputDim = inputSizeDims[i];
+    int64_t kernelDim = kernelSizeDims[i];
+    if (inputTy.isDynamicDim(inputDim)) {
+      auto padTop = padAttr[i * 2];
+      auto padBottom = padAttr[i * 2 + 1];
+      auto stride = strideAttr[i];
+      auto dilation = dilationAttr[i];
+      Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
+      Value kernelDynDim =
+          rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
+      // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
+      dynDims[inputDim] =
+          getConvOutputDim(loc, initDynDim, padTop, padBottom, kernelDynDim,
+                           stride, dilation, inputETy, rewriter);
+    }
-  // Dynamic input weight
-  if (inputTy.isDynamicDim(weightDim)) {
-    Value initWDim =
-        rewriter.create<tensor::DimOp>(loc, input, weightDim).getResult();
-    Value kernelWDim =
-        rewriter.create<tensor::DimOp>(loc, weight, weightWDim).getResult();
-    // W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x)
-    dynDims[weightDim] =
-        getConvOutputDim(loc, initWDim, padAttr[2], padAttr[3], kernelWDim,
-                         strideAttr[1], dilationAttr[1], inputETy, rewriter);
+  // Get the batch/channels dimensions.
+  for (int i = 0; i < inputRank; i++) {
+    if (inputTy.isDynamicDim(i) && !dynDims[i])
+      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
   SmallVector<Value> filteredDims = condenseValues(dynDims);
@@ -161,21 +154,23 @@ static void createDepthwiseConvCollapseMap(
 namespace {
-class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
+template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
+class ConvConverter : public OpConversionPattern<TosaConvOp> {
-  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
+  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
-  matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor,
+  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     Location loc = op->getLoc();
     Value input = op->getOperand(0);
     Value weight = op->getOperand(1);
     Value bias = op->getOperand(2);
-    ShapedType inputTy = input.getType().cast<ShapedType>();
-    ShapedType weightTy = weight.getType().cast<ShapedType>();
-    ShapedType biasTy = bias.getType().cast<ShapedType>();
-    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+    ShapedType inputTy = input.getType().template cast<ShapedType>();
+    ShapedType weightTy = weight.getType().template cast<ShapedType>();
+    ShapedType biasTy = bias.getType().template cast<ShapedType>();
+    ShapedType resultTy =
+        op->getResult(0).getType().template cast<ShapedType>();
     Type inputETy = inputTy.getElementType();
     Type resultETy = resultTy.getElementType();
@@ -183,7 +178,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     DenseI64ArrayAttr padAttr = op.getPadAttr();
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
-    bool isQuantized = op->hasAttr("quantization_info");
+    bool isQuantized = op.getQuantizationInfo().has_value();
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
@@ -193,17 +188,24 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
       return rewriter.notifyMatchFailure(
           op, "tosa.conv ops does not support unsigned integer input");
+    llvm::SmallVector<int64_t> inputSizeDims;
+    llvm::SmallVector<int64_t> kernelSizeDims;
+    for (int i = 1; i < resultTy.getRank() - 1; i++) {
+      inputSizeDims.push_back(i);
+      kernelSizeDims.push_back(i);
+    }
     SmallVector<Value> filteredDims = inferDynamicDimsForConv(
-        loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
-        /*weightHDim=*/1, /*weightWDim=*/2, rewriter);
+        loc, input, weight, resultTy, padAttr.asArrayRef(),
+        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
+        inputSizeDims, kernelSizeDims, rewriter);
     auto weightShape = weightTy.getShape();
     // Apply padding as necessary.
     Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
     if (isQuantized) {
-      auto quantizationInfo =
-          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
+      auto quantizationInfo = *op.getQuantizationInfo();
       int64_t iZp = quantizationInfo.getInputZp();
       int64_t intMin =
@@ -230,11 +232,15 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     // convolution operation.
     // TODO(suderman): See if this can be efficiently folded - check whether
     // the input is used anywhere else, if not fold the constant.
-    SmallVector<int64_t> weightPerm{1, 2, 3, 0};
-    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[2],
-                                        weightShape[3], weightShape[0]};
-    auto weightPermAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm);
+    SmallVector<int64_t> weightPerm;
+    for (int i = 1; i < resultTy.getRank(); i++)
+      weightPerm.push_back(i);
+    weightPerm.push_back(0);
+    SmallVector<int64_t> newWeightShape;
+    for (auto dim : weightPerm)
+      newWeightShape.push_back(weightShape[dim]);
+    auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
     Value weightPermValue =
         rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
     Type newWeightTy =
@@ -256,16 +262,15 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     ArrayRef<int64_t> dilation = dilationTosaAttr;
     // Create the convolution op.
-    auto strideAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
-    auto dilationAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
+    auto strideAttr = rewriter.getI64TensorAttr(stride);
+    auto dilationAttr = rewriter.getI64TensorAttr(dilation);
     // Create maps for the bias broadcasting
     SmallVector<AffineMap, 4> indexingMaps;
         /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
-        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
+        {rewriter.getAffineDimExpr(resultTy.getRank() - 1)},
+        rewriter.getContext()));
@@ -273,8 +278,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
         loc, resultTy.getShape(), resultETy, filteredDims);
     if (isQuantized) {
-      auto quantizationInfo =
-          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
+      auto quantizationInfo = *op.getQuantizationInfo();
       auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
       auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
@@ -282,7 +286,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
       Value conv =
-              .create<linalg::Conv2DNhwcHwcfQOp>(
+              .create<LinalgConvQOp>(
                   loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                   ValueRange{zeroTensor}, strideAttr, dilationAttr)
@@ -304,7 +308,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
     Value conv = rewriter
-                     .create<linalg::Conv2DNhwcHwcfOp>(
+                     .create<LinalgConvOp>(
                          loc, resultTy, ValueRange{input, weight},
                          ValueRange{zeroTensor}, strideAttr, dilationAttr)
@@ -358,8 +362,10 @@ class DepthwiseConvConverter
     // Compute output dynamic dims
     SmallVector<Value> filteredDims = inferDynamicDimsForConv(
-        loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
-        0, 1, rewriter);
+        loc, input, weight, resultTy, padAttr.asArrayRef(),
+        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
+        /*inputSizeDims=*/{1, 2},
+        /*kernelSizeDims=*/{0, 1}, rewriter);
     bool isQuantized = op->hasAttr("quantization_info");
     IntegerAttr iZp;
@@ -408,11 +414,8 @@ class DepthwiseConvConverter
     ArrayRef<int64_t> dilation = dilationTosaAttr;
     // Create the convolution op.
-    auto strideAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
-    auto dilationAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
+    auto strideAttr = rewriter.getI64TensorAttr(stride);
+    auto dilationAttr = rewriter.getI64TensorAttr(dilation);
     ShapedType linalgConvTy =
         RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                                weightShape[2], weightShape[3]},
@@ -610,8 +613,7 @@ class FullyConnectedConverter
     SmallVector<int64_t> permutation{1, 0};
-    auto permutationAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
+    auto permutationAttr = rewriter.getI64TensorAttr(permutation);
     Value permutationValue =
         rewriter.create<arith::ConstantOp>(loc, permutationAttr);
@@ -966,7 +968,8 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
     RewritePatternSet *patterns) {
       // clang-format off
-      ConvConverter,
+      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcHwcfQOp>,
+      ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index ac67faa4b3ca4..4c941a109ed84 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -51,6 +51,7 @@ struct TosaToLinalgNamed
     // Not every TOSA op can be legalized to linalg.
+    target.addIllegalOp<tosa::Conv3DOp>();

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 8bab1607b4ed4..4402624c174b2 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -150,6 +150,7 @@ def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
                        TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(
                            U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
 def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
                         B=TensorDef(T2, Batch, S.K, S.N),
@@ -162,8 +163,9 @@ def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
   domain(D.b, D.m, D.n, D.k)
-  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k] * TypeFn.cast_signed(
-    U, B[D.b, D.k, D.n]))
+  C[D.m, D.n] += TypeFn.cast_signed(
+      U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n]))
 def matvec(A=TensorDef(T1, S.M, S.N),
@@ -283,6 +285,7 @@ def conv_1d_nwc_wcf(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
       U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
           U, K[D.kw, D.c, D.f])
 def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
                     K=TensorDef(T2, S.F, S.C, S.KW),
@@ -304,6 +307,7 @@ def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
       U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
           U, K[D.f, D.c, D.kw])
 def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                                   S.OW * S.SW + S.KW * S.DW, S.C),
@@ -400,13 +404,15 @@ def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
       U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
            D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
-def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW),
-                      K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
-                      O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
-                      strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                      dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C,
+                                    S.OH * S.SH + S.KH * S.DH,
+                                    S.OW * S.SW + S.KW * S.DW),
+                        K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
+                        O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
+                        strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+                        dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
   """Performs 2-D grouped convolution.
@@ -420,7 +426,8 @@ def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH
   domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
   O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
       U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
-          D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
+           D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
 def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
@@ -449,6 +456,43 @@ def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
                U, K[D.kd, D.kh, D.kw, D.c, D.f])
+@linalg_structured_op
+def conv_3d_ndhwc_dhwcf_q(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
+                                      S.OH * S.SH + S.KH * S.DH,
+                                      S.OW * S.SW + S.KW * S.DW, S.C),
+                          K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
+                          IZp=ScalarDef(I32),
+                          KZp=ScalarDef(I32),
+                          O=TensorDef(U,
+                                      S.N,
+                                      S.OD,
+                                      S.OH,
+                                      S.OW,
+                                      S.F,
+                                      output=True),
+                          strides=IndexAttrDef(S.SD,
+                                               S.SH,
+                                               S.SW,
+                                               default=[1, 1, 1]),
+                          dilations=IndexAttrDef(S.DD,
+                                                 S.DH,
+                                                 S.DW,
+                                                 default=[1, 1, 1])):
+  """Performs 3-D convolution with zero point offsets.
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output. This includes the zero
+  point offsets common to quantized operations.
+  """
+  implements(ConvolutionOpInterface)
+  domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+  O[D.n, D.od, D.oh, D.ow, D.f] += (TypeFn.cast_signed(
+      U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
+           D.ow * S.SW + D.kw * S.DW, D.c]) - TypeFn.cast_signed(U, IZp)) * (
+               TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) -
+               TypeFn.cast_signed(U, KZp))
 def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW,
@@ -517,7 +561,8 @@ def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH,
+def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC,
+                                           S.OH * S.SH + S.KH * S.DH,
                                            S.OW * S.SW + S.KW * S.DW),
                                K=TensorDef(T2, S.IC, S.KH, S.KW),
@@ -539,7 +584,8 @@ def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S
   domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
   O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
+      U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
+           D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
@@ -642,7 +688,11 @@ def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
                                              S.OH * S.SH + S.KH * S.DH,
                                              S.OW * S.SW + S.KW * S.DW, S.IC),
                                  K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC),
-                                 O=TensorDef(U, S.N, S.OD, S.OH, S.OW,
+                                 O=TensorDef(U,
+                                             S.N,
+                                             S.OD,
+                                             S.OH,
+                                             S.OW,
@@ -667,12 +717,17 @@ def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1,
-                                              S.N, S.OD * S.SD + S.KD * S.DD,
+def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, S.N,
+                                              S.OD * S.SD + S.KD * S.DD,
                                               S.OH * S.SH + S.KH * S.DH,
                                               S.OW * S.SW + S.KW * S.DW, S.IC),
                                   K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM),
-                                  O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.CM,
+                                  O=TensorDef(U,
+                                              S.N,
+                                              S.OD,
+                                              S.OH,
+                                              S.OW,
+                                              S.CM,

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index e6427174b66a3..5a28597052c58 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -603,3 +603,55 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
   %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array<i64: 1, 2, 3, 4>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 2>} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32>
+// -----
+// CHECK-LABEL: @conv3d_f32
+func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
+  // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = "tosa.transpose"(%arg1, %[[PERMS]])
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
+  // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
+  // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>)
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
+  // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf
+  // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>)
+  // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xf32>, tensor<1x47x45x43x28xf32>)
+  // CHECK--SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>) {
+  // CHECK: ^bb0(%[[A1:.+]]: f32, %[[A2:.+]]: f32, %{{.+}}: f32):
+  // CHECK: %[[ADD:.+]] = arith.addf %[[A1]], %[[A2]] : f32
+  // CHECK: linalg.yield %[[ADD]]
+  %0 = "tosa.conv3d"(%input, %weights, %bias) {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>)  -> tensor<1x47x45x43x28xf32>
+  return
+// -----
+// CHECK-LABEL: @conv3d_i8
+func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
+    // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = "tosa.transpose"(%arg1, %[[PERMS]])
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
+  // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
+  // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : i32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>)
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
+  // CHECK-DAG: %[[IZP:.+]] = arith.constant -128 : i32
+  // CHECK-DAG: %[[FZP:.+]] = arith.constant 42 : i32
+  // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf_q
+  // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
+  // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xi32>, tensor<1x47x45x43x28xi32>)
+  // CHECK--SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>) {
+  // CHECK: ^bb0(%[[A1:.+]]: i32, %[[A2:.+]]: i32, %{{.+}}: i32):
+  // CHECK: %[[ADD:.+]] = arith.addi %[[A1]], %[[A2]] : i32
+  // CHECK: linalg.yield %[[ADD]]
+  %0 = "tosa.conv3d"(%input, %weights, %bias) {pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>)  -> tensor<1x47x45x43x28xi32>
+  return


More information about the Mlir-commits mailing list