[Mlir-commits] [mlir] 7ce53e3 - [mlir][tosa] Add tosa.conv3d lowering to Linalg
Rob Suderman
llvmlistbot@llvm.org
Fri Jan 6 10:47:54 PST 2023
Author: Rob Suderman
Date: 2023-01-06T10:47:45-08:00
New Revision: 7ce53e31023dcf9d8fb95d172e20a35e60ebd821
URL: https://github.com/llvm/llvm-project/commit/7ce53e31023dcf9d8fb95d172e20a35e60ebd821
DIFF: https://github.com/llvm/llvm-project/commit/7ce53e31023dcf9d8fb95d172e20a35e60ebd821.diff
LOG: [mlir][tosa] Add tosa.conv3d lowering to Linalg
Conv3D already has a Linalg named operation for the floating-point case. This adds a
quantized variant and the corresponding lowering from TOSA. Numerical correctness was
validated using the TOSA conformance tests.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D140919
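As a quick illustration of the new path (an editorial sketch, not part of the commit): a
minimal tosa.conv3d in the layouts exercised by the tests below, which the
tosa-to-linalg-named patterns should rewrite into linalg.conv_3d_ndhwc_dhwcf followed by a
bias-add linalg.generic. The pass-pipeline string and the shapes here are assumptions.

  // Assumed invocation, mirroring the RUN line typically used for these tests:
  //   mlir-opt --pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" conv3d.mlir
  func.func @conv3d_example(%input: tensor<1x8x8x8x4xf32>, %weights: tensor<8x3x3x3x4xf32>,
                            %bias: tensor<8xf32>) -> tensor<1x6x6x6x8xf32> {
    // Weights are [F, KD, KH, KW, C]; the result is [N, OD, OH, OW, F] = 1x6x6x6x8.
    %0 = "tosa.conv3d"(%input, %weights, %bias) {pad = array<i64: 0, 0, 0, 0, 0, 0>,
        stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>}
        : (tensor<1x8x8x8x4xf32>, tensor<8x3x3x3x4xf32>, tensor<8xf32>) -> tensor<1x6x6x6x8xf32>
    return %0 : tensor<1x6x6x6x8xf32>
  }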
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 9d771fc1b738a..249f60ba02979 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2001,6 +2001,145 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: K
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: conv_3d_ndhwc_dhwcf_q
+ cpp_class_name: Conv3DNdhwcDhwcfQOp
+ doc: |-
+ Performs 3-D convolution with zero point offsets.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. This includes the zero
+ point offsets common to quantized operations.
+ implements:
+ - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
+ s13, s14] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12,
+ s13)>
+ - !LinalgOperandDefConfig
+ name: K
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
+ s13, s14] -> (s3, s7, s11, s13, s14)>
+ - !LinalgOperandDefConfig
+ name: IZp
+ kind: scalar
+ type_var: I32
+ - !LinalgOperandDefConfig
+ name: KZp
+ kind: scalar
+ type_var: I32
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
+ s13, s14] -> (s0, s1, s5, s9, s14)>
+ - !LinalgOperandDefConfig
+ name: strides
+ kind: index_attr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+ s12, s13, s14] -> (s2, s6, s10)>
+ default_indices:
+ - 1
+ - 1
+ - 1
+ - !LinalgOperandDefConfig
+ name: dilations
+ kind: index_attr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+ s12, s13, s14] -> (s4, s8, s12)>
+ default_indices:
+ - 1
+ - 1
+ - 1
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+ s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6
+ * s8, d3 * s10 + d7 * s12, d8)>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+ s7, s8, s9, s10, s11, s12, s13, s14] -> (d5, d6, d7, d8, d4)>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+ s7, s8, s9, s10, s11, s12, s13, s14] -> ()>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+ s7, s8, s9, s10, s11, s12, s13, s14] -> ()>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+ s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d1, d2, d3, d4)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ - reduction
+ - reduction
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: O
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: sub
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: IZp
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: sub
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: K
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: KZp
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_1d_nwc_wc
cpp_class_name: DepthwiseConv1DNwcWcOp
@@ -4441,3 +4580,4 @@ structured_op: !LinalgStructuredOpConfig
scalar_const: '2.3283063999999999E-10 : f64'
- !ScalarExpression
scalar_arg: min
+
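The scalar assignment tree above encodes the usual zero-point-corrected accumulation. As a
single update, using the same symbols as the OpDSL definition further down in this patch
(an editorial restatement, not part of the generated YAML):

  O[n, od, oh, ow, f] +=
      (cast_signed(U, I[n, od*SD + kd*DD, oh*SH + kh*DH, ow*SW + kw*DW, c]) - cast_signed(U, IZp))
    * (cast_signed(U, K[kd, kh, kw, c, f]) - cast_signed(U, KZp))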
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 52f9d3416e2a6..f9732efaa27eb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -99,47 +99,40 @@ static mlir::Value getConvOutputDim(Location loc, Value inputDim,
}
// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
-static SmallVector<Value>
-inferDynamicDimsForConv(Location loc, Value input, Value weight,
- ShapedType resultTy, DenseI64ArrayAttr padAttr,
- DenseI64ArrayAttr strideAttr,
- DenseI64ArrayAttr dilationAttr, int64_t weightHDim,
- int64_t weightWDim, OpBuilder &rewriter) {
+static SmallVector<Value> inferDynamicDimsForConv(
+ Location loc, Value input, Value weight, ShapedType resultTy,
+ ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
+ ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
+ ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
ShapedType inputTy = input.getType().cast<ShapedType>();
Type inputETy = inputTy.getElementType();
int64_t inputRank = inputTy.getRank();
- int64_t heightDim = 1;
- int64_t weightDim = 2;
SmallVector<Value> dynDims;
dynDims.resize(resultTy.getRank());
- for (int i = 0; i < inputRank; i++) {
- if (inputTy.isDynamicDim(i) && i != heightDim && i != weightDim)
- dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
- }
- // Dynamic input height
- if (inputTy.isDynamicDim(heightDim)) {
- Value initHDim =
- rewriter.create<tensor::DimOp>(loc, input, heightDim).getResult();
- Value kernelHDim =
- rewriter.create<tensor::DimOp>(loc, weight, weightHDim).getResult();
- // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
- dynDims[heightDim] =
- getConvOutputDim(loc, initHDim, padAttr[0], padAttr[1], kernelHDim,
- strideAttr[0], dilationAttr[0], inputETy, rewriter);
+ for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
+ int64_t inputDim = inputSizeDims[i];
+ int64_t kernelDim = kernelSizeDims[i];
+ if (inputTy.isDynamicDim(inputDim)) {
+ auto padTop = padAttr[i * 2];
+ auto padBottom = padAttr[i * 2 + 1];
+ auto stride = strideAttr[i];
+ auto dilation = dilationAttr[i];
+ Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
+ Value kernelDynDim =
+ rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
+ // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
+ dynDims[inputDim] =
+ getConvOutputDim(loc, initDynDim, padTop, padBottom, kernelDynDim,
+ stride, dilation, inputETy, rewriter);
+ }
}
- // Dynamic input weight
- if (inputTy.isDynamicDim(weightDim)) {
- Value initWDim =
- rewriter.create<tensor::DimOp>(loc, input, weightDim).getResult();
- Value kernelWDim =
- rewriter.create<tensor::DimOp>(loc, weight, weightWDim).getResult();
- // W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x)
- dynDims[weightDim] =
- getConvOutputDim(loc, initWDim, padAttr[2], padAttr[3], kernelWDim,
- strideAttr[1], dilationAttr[1], inputETy, rewriter);
+ // Get the batch/channels dimensions.
+ for (int i = 0; i < inputRank; i++) {
+ if (inputTy.isDynamicDim(i) && !dynDims[i])
+ dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
}
SmallVector<Value> filteredDims = condenseValues(dynDims);
@@ -161,21 +154,23 @@ static void createDepthwiseConvCollapseMap(
namespace {
-class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
+template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
+class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
- using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
+ using OpConversionPattern<TosaConvOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor,
+ matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
Value input = op->getOperand(0);
Value weight = op->getOperand(1);
Value bias = op->getOperand(2);
- ShapedType inputTy = input.getType().cast<ShapedType>();
- ShapedType weightTy = weight.getType().cast<ShapedType>();
- ShapedType biasTy = bias.getType().cast<ShapedType>();
- ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+ ShapedType inputTy = input.getType().template cast<ShapedType>();
+ ShapedType weightTy = weight.getType().template cast<ShapedType>();
+ ShapedType biasTy = bias.getType().template cast<ShapedType>();
+ ShapedType resultTy =
+ op->getResult(0).getType().template cast<ShapedType>();
Type inputETy = inputTy.getElementType();
Type resultETy = resultTy.getElementType();
@@ -183,7 +178,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
DenseI64ArrayAttr padAttr = op.getPadAttr();
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
- bool isQuantized = op->hasAttr("quantization_info");
+ bool isQuantized = op.getQuantizationInfo().has_value();
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return rewriter.notifyMatchFailure(
@@ -193,17 +188,24 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
return rewriter.notifyMatchFailure(
op, "tosa.conv ops does not support unsigned integer input");
+ llvm::SmallVector<int64_t> inputSizeDims;
+ llvm::SmallVector<int64_t> kernelSizeDims;
+ for (int i = 1; i < resultTy.getRank() - 1; i++) {
+ inputSizeDims.push_back(i);
+ kernelSizeDims.push_back(i);
+ }
+
SmallVector<Value> filteredDims = inferDynamicDimsForConv(
- loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
- /*weightHDim=*/1, /*weightWDim=*/2, rewriter);
+ loc, input, weight, resultTy, padAttr.asArrayRef(),
+ strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
+ inputSizeDims, kernelSizeDims, rewriter);
auto weightShape = weightTy.getShape();
// Apply padding as necessary.
Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
if (isQuantized) {
- auto quantizationInfo =
- op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
+ auto quantizationInfo = *op.getQuantizationInfo();
int64_t iZp = quantizationInfo.getInputZp();
int64_t intMin =
@@ -230,11 +232,15 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
// convolution operation.
// TODO(suderman): See if this can be efficiently folded - check whether
// the input is used anywhere else, if not fold the constant.
- SmallVector<int64_t> weightPerm{1, 2, 3, 0};
- SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[2],
- weightShape[3], weightShape[0]};
- auto weightPermAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm);
+ SmallVector<int64_t> weightPerm;
+ for (int i = 1; i < resultTy.getRank(); i++)
+ weightPerm.push_back(i);
+ weightPerm.push_back(0);
+
+ SmallVector<int64_t> newWeightShape;
+ for (auto dim : weightPerm)
+ newWeightShape.push_back(weightShape[dim]);
+ auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
Value weightPermValue =
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
Type newWeightTy =
@@ -256,16 +262,15 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
ArrayRef<int64_t> dilation = dilationTosaAttr;
// Create the convolution op.
- auto strideAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({2}, rewriter.getI64Type()), stride);
- auto dilationAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
+ auto strideAttr = rewriter.getI64TensorAttr(stride);
+ auto dilationAttr = rewriter.getI64TensorAttr(dilation);
// Create maps for the bias broadcasting
SmallVector<AffineMap, 4> indexingMaps;
indexingMaps.push_back(AffineMap::get(
/*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
- {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
+ {rewriter.getAffineDimExpr(resultTy.getRank() - 1)},
+ rewriter.getContext()));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
@@ -273,8 +278,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
loc, resultTy.getShape(), resultETy, filteredDims);
if (isQuantized) {
- auto quantizationInfo =
- op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
+ auto quantizationInfo = *op.getQuantizationInfo();
auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
@@ -282,7 +286,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
Value conv =
rewriter
- .create<linalg::Conv2DNhwcHwcfQOp>(
+ .create<LinalgConvQOp>(
loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
ValueRange{zeroTensor}, strideAttr, dilationAttr)
->getResult(0);
@@ -304,7 +308,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
}
Value conv = rewriter
- .create<linalg::Conv2DNhwcHwcfOp>(
+ .create<LinalgConvOp>(
loc, resultTy, ValueRange{input, weight},
ValueRange{zeroTensor}, strideAttr, dilationAttr)
->getResult(0);
@@ -358,8 +362,10 @@ class DepthwiseConvConverter
// Compute output dynamic dims
SmallVector<Value> filteredDims = inferDynamicDimsForConv(
- loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
- 0, 1, rewriter);
+ loc, input, weight, resultTy, padAttr.asArrayRef(),
+ strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
+ /*inputSizeDims=*/{1, 2},
+ /*kernelSizeDims=*/{0, 1}, rewriter);
bool isQuantized = op->hasAttr("quantization_info");
IntegerAttr iZp;
@@ -408,11 +414,8 @@ class DepthwiseConvConverter
ArrayRef<int64_t> dilation = dilationTosaAttr;
// Create the convolution op.
- auto strideAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({2}, rewriter.getI64Type()), stride);
- auto dilationAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
-
+ auto strideAttr = rewriter.getI64TensorAttr(stride);
+ auto dilationAttr = rewriter.getI64TensorAttr(dilation);
ShapedType linalgConvTy =
RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
weightShape[2], weightShape[3]},
@@ -610,8 +613,7 @@ class FullyConnectedConverter
.result();
SmallVector<int64_t> permutation{1, 0};
- auto permutationAttr = DenseIntElementsAttr::get(
- RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
+ auto permutationAttr = rewriter.getI64TensorAttr(permutation);
Value permutationValue =
rewriter.create<arith::ConstantOp>(loc, permutationAttr);
@@ -966,7 +968,8 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
RewritePatternSet *patterns) {
patterns->add<
// clang-format off
- ConvConverter,
+ ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcHwcfQOp>,
+ ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
DepthwiseConvConverter,
MatMulConverter,
MaxPool2dConverter,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index ac67faa4b3ca4..4c941a109ed84 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -51,6 +51,7 @@ struct TosaToLinalgNamed
// Not every TOSA op can be legalized to linalg.
target.addIllegalOp<tosa::Conv2DOp>();
+ target.addIllegalOp<tosa::Conv3DOp>();
target.addIllegalOp<tosa::DepthwiseConv2DOp>();
target.addIllegalOp<tosa::MaxPool2dOp>();
target.addIllegalOp<tosa::AvgPool2dOp>();
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 8bab1607b4ed4..4402624c174b2 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -150,6 +150,7 @@ def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(
U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
+
@linalg_structured_op
def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
B=TensorDef(T2, Batch, S.K, S.N),
@@ -162,8 +163,9 @@ def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
"""
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
- C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k] * TypeFn.cast_signed(
- U, B[D.b, D.k, D.n]))
+ C[D.m, D.n] += TypeFn.cast_signed(
+ U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n]))
+
@linalg_structured_op
def matvec(A=TensorDef(T1, S.M, S.N),
@@ -283,6 +285,7 @@ def conv_1d_nwc_wcf(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
U, K[D.kw, D.c, D.f])
+
@linalg_structured_op
def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
K=TensorDef(T2, S.F, S.C, S.KW),
@@ -304,6 +307,7 @@ def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
U, K[D.f, D.c, D.kw])
+
@linalg_structured_op
def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
S.OW * S.SW + S.KW * S.DW, S.C),
@@ -400,13 +404,15 @@ def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
+
@linalg_structured_op
-def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH,
- S.OW * S.SW + S.KW * S.DW),
- K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
- O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
- strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
- dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
+def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C,
+ S.OH * S.SH + S.KH * S.DH,
+ S.OW * S.SW + S.KW * S.DW),
+ K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
+ O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs 2-D grouped convolution.
Layout:
@@ -420,7 +426,8 @@ def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH
domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
- D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
+ D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
+
@linalg_structured_op
def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
@@ -449,6 +456,43 @@ def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
U, K[D.kd, D.kh, D.kw, D.c, D.f])
+@linalg_structured_op
+def conv_3d_ndhwc_dhwcf_q(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
+ S.OH * S.SH + S.KH * S.DH,
+ S.OW * S.SW + S.KW * S.DW, S.C),
+ K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
+ IZp=ScalarDef(I32),
+ KZp=ScalarDef(I32),
+ O=TensorDef(U,
+ S.N,
+ S.OD,
+ S.OH,
+ S.OW,
+ S.F,
+ output=True),
+ strides=IndexAttrDef(S.SD,
+ S.SH,
+ S.SW,
+ default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD,
+ S.DH,
+ S.DW,
+ default=[1, 1, 1])):
+ """Performs 3-D convolution with zero point offsets.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output. This includes the zero
+ point offsets common to quantized operations.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+ O[D.n, D.od, D.oh, D.ow, D.f] += (TypeFn.cast_signed(
+ U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW, D.c]) - TypeFn.cast_signed(U, IZp)) * (
+ TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) -
+ TypeFn.cast_signed(U, KZp))
+
+
@linalg_structured_op
def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW,
S.IC),
@@ -517,7 +561,8 @@ def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
@linalg_structured_op
-def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH,
+def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC,
+ S.OH * S.SH + S.KH * S.DH,
S.OW * S.SW + S.KW * S.DW),
K=TensorDef(T2, S.IC, S.KH, S.KW),
O=TensorDef(U,
@@ -539,7 +584,8 @@ def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed(
- U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
+ U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
+ D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
@linalg_structured_op
@@ -642,7 +688,11 @@ def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
S.OH * S.SH + S.KH * S.DH,
S.OW * S.SW + S.KW * S.DW, S.IC),
K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC),
- O=TensorDef(U, S.N, S.OD, S.OH, S.OW,
+ O=TensorDef(U,
+ S.N,
+ S.OD,
+ S.OH,
+ S.OW,
output=True),
strides=IndexAttrDef(S.SD,
S.SH,
@@ -667,12 +717,17 @@ def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
@linalg_structured_op
-def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1,
- S.N, S.OD * S.SD + S.KD * S.DD,
+def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, S.N,
+ S.OD * S.SD + S.KD * S.DD,
S.OH * S.SH + S.KH * S.DH,
S.OW * S.SW + S.KW * S.DW, S.IC),
K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM),
- O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.CM,
+ O=TensorDef(U,
+ S.N,
+ S.OD,
+ S.OH,
+ S.OW,
+ S.CM,
output=True),
strides=IndexAttrDef(S.SD,
S.SH,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index e6427174b66a3..5a28597052c58 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -603,3 +603,55 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array<i64: 1, 2, 3, 4>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 2>} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: @conv3d_f32
+func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
+ // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+ // CHECK-DAG: %[[TRANSPOSE:.+]] = "tosa.transpose"(%arg1, %[[PERMS]])
+ // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
+ // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
+ // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>)
+ // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
+ // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf
+ // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+ // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>)
+ // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+ // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xf32>, tensor<1x47x45x43x28xf32>)
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>) {
+ // CHECK: ^bb0(%[[A1:.+]]: f32, %[[A2:.+]]: f32, %{{.+}}: f32):
+ // CHECK: %[[ADD:.+]] = arith.addf %[[A1]], %[[A2]] : f32
+ // CHECK: linalg.yield %[[ADD]]
+ %0 = "tosa.conv3d"(%input, %weights, %bias) {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>) -> tensor<1x47x45x43x28xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_i8
+func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
+ // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+ // CHECK-DAG: %[[TRANSPOSE:.+]] = "tosa.transpose"(%arg1, %[[PERMS]])
+ // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
+ // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
+ // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : i32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>)
+ // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
+ // CHECK-DAG: %[[IZP:.+]] = arith.constant -128 : i32
+ // CHECK-DAG: %[[FZP:.+]] = arith.constant 42 : i32
+ // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf_q
+ // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+ // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
+ // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+ // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xi32>, tensor<1x47x45x43x28xi32>)
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>) {
+ // CHECK: ^bb0(%[[A1:.+]]: i32, %[[A2:.+]]: i32, %{{.+}}: i32):
+ // CHECK: %[[ADD:.+]] = arith.addi %[[A1]], %[[A2]] : i32
+ // CHECK: linalg.yield %[[ADD]]
+ %0 = "tosa.conv3d"(%input, %weights, %bias) {pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32>
+ return
+}