[Mlir-commits] [mlir] 9c923f4 - [mlir][linalg] Fix vectorization of linalg depthwise conv for int types

Rob Suderman llvmlistbot at llvm.org
Tue Nov 8 16:32:07 PST 2022

Author: Rob Suderman
Date: 2022-11-08T16:21:05-08:00
New Revision: 9c923f4e58357af870925781f0b33d1ee958b9d8

URL: https://github.com/llvm/llvm-project/commit/9c923f4e58357af870925781f0b33d1ee958b9d8
DIFF: https://github.com/llvm/llvm-project/commit/9c923f4e58357af870925781f0b33d1ee958b9d8.diff

LOG: [mlir][linalg] Fix vectorization of linalg depthwise conv for int types

Vectorization of Linalg's depthwise convolution only supported floating
point types: the previous version assumed floating point operations
(vector.fma) would work for any element type. This version checks whether
the computation is integer or floating point and adjusts the inner loop
computation accordingly.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D137595




diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index cedec72b9cb33..2cf74a67df20e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1746,11 +1746,17 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = depthwiseConv1dSliceAsFma(
+        resVals[w] = depthwiseConv1dSliceAsMulAcc(
             builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
+    // It's possible we failed to create the multiply-accumulate.
+    for (auto v : resVals) {
+      if (!v)
+        return failure();
+    }
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
@@ -1770,11 +1776,45 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
-  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma.
-  Value depthwiseConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
-                                  Value rhs, Value res) {
-    Value bcast = builder.create<vector::BroadcastOp>(loc, res.getType(), rhs);
-    return b.create<vector::FMAOp>(loc, lhs, bcast, res);
+  /// Widen `val` to the type `ty` by extending its element type: arith.extf
+  /// when both element types are floats, arith.extsi (i.e. integers are
+  /// treated as signed) when both are integers. Returns `val` unchanged if its
+  /// type already equals `ty`. Returns a null Value when no widening applies:
+  /// a non int/float element type, a source that is not strictly narrower
+  /// than the destination, or mixed int/float element types. Callers must
+  /// check the result for null.
+  Value promote(OpBuilder &b, Location loc, Value val, Type ty) {
+    if (val.getType() == ty)
+      return val;
+    Type srcElemTy = getElementTypeOrSelf(val.getType());
+    Type destElemTy = getElementTypeOrSelf(ty);
+    // getIntOrFloatBitWidth() asserts on other types (e.g. index), so bail
+    // out before querying widths.
+    if (!srcElemTy.isIntOrFloat() || !destElemTy.isIntOrFloat())
+      return nullptr;
+    const unsigned srcWidth = srcElemTy.getIntOrFloatBitWidth();
+    const unsigned destWidth = destElemTy.getIntOrFloatBitWidth();
+    // Only strict widening is supported.
+    if (srcWidth >= destWidth)
+      return nullptr;
+    // Both element types must be of the same kind: extf/extsi are only valid
+    // for float->float and int->int respectively.
+    if (srcElemTy.isa<FloatType>() && destElemTy.isa<FloatType>())
+      return b.create<arith::ExtFOp>(loc, ty, val);
+    if (srcElemTy.isa<IntegerType>() && destElemTy.isa<IntegerType>())
+      return b.create<arith::ExtSIOp>(loc, ty, val);
+    // Mixed int/float promotion would need sitofp/fptosi; not handled here.
+    return nullptr;
+  }
+  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
+  Value depthwiseConv1dSliceAsMulAcc(OpBuilder &b, Location loc, Value lhs,
+                                     Value rhs, Value res) {
+    auto rhsTy = rhs.getType().cast<ShapedType>();
+    auto resTy = res.getType().cast<ShapedType>();
+    // TODO(suderman): Change this to use a vector.ima intrinsic.
+    lhs = promote(b, loc, lhs, resTy);
+    rhs = builder.create<vector::BroadcastOp>(
+        loc, resTy.clone(rhsTy.getElementType()), rhs);
+    rhs = promote(b, loc, rhs, resTy);
+    if (!lhs || !rhs)
+      return nullptr;
+    if (resTy.getElementType().isa<FloatType>())
+      return b.create<vector::FMAOp>(loc, lhs, rhs, res);
+    auto mul = b.create<arith::MulIOp>(loc, lhs, rhs);
+    return b.create<arith::AddIOp>(loc, mul, res);
   /// Entry point that transposes into the common form:

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 1374c996128a1..f1f00cf16d1b0 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -463,7 +463,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 // -----
-func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
     ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
@@ -471,7 +471,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt
-//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4_memref
+//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref
 //  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
@@ -502,6 +502,51 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filt
 //      CHECK:   vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// -----
+func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %filter: memref<2x4xi8>, %output: memref<3x2x4xi32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
+    outs(%output : memref<3x2x4xi32>)
+  return
+//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref
+//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xi8>, %[[FILTER:[0-9a-z]+]]: memref<2x4xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xi32>)
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+/// Read the whole data in one shot.
+//      CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+//      CHECK:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xi8>
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xi8>
+/// w == 0, kw = 0
+//      CHECK:  %[[EXT_INPUT_0:.*]] = arith.extsi %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+//      CHECK:  %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8>
+//      CHECK:  %[[EXT_FILTER_0:.*]] = arith.extsi %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+//      CHECK:  %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x2x4xi32>
+//      CHECK:  %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[V_OUTPUT_R]] : vector<3x2x4xi32>
+/// w == 0, kw = 1
+//      CHECK:  %[[EXT_INPUT_1:.*]] = arith.extsi %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+//      CHECK:  %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8>
+//      CHECK:  %[[EXT_FILTER_1:.*]] = arith.extsi %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+//      CHECK:  %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x2x4xi32>
+//      CHECK:  %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x2x4xi32>
+// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[ADD_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 // -----
 func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) {


More information about the Mlir-commits mailing list