[Mlir-commits] [mlir] 5299953 - [mlir][linalg] Add vectorization support for conv_1d

Hanhan Wang llvmlistbot at llvm.org
Wed Mar 8 14:23:47 PST 2023


Author: Devajith Valaparambil Sreeramaswamy
Date: 2023-03-08T14:23:36-08:00
New Revision: 5299953ababd25c00ad97d0db4a4a795359c1058

URL: https://github.com/llvm/llvm-project/commit/5299953ababd25c00ad97d0db4a4a795359c1058
DIFF: https://github.com/llvm/llvm-project/commit/5299953ababd25c00ad97d0db4a4a795359c1058.diff

LOG: [mlir][linalg] Add vectorization support for conv_1d

This MR adds vectorization support for the linalg.conv_1d operation.
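
For reference, the new test case added to vectorize-convolution.mlir exercises
an op of the following form (copied from that test; the shapes are just the
ones used there):

  func.func @conv1d_8_tensor(%input: tensor<11xf32>, %filter: tensor<4xf32>,
                             %output: tensor<8xf32>) -> tensor<8xf32> {
    %0 = linalg.conv_1d ins(%input, %filter : tensor<11xf32>, tensor<4xf32>)
                       outs(%output : tensor<8xf32>) -> tensor<8xf32>
    return %0 : tensor<8xf32>
  }

The vectorizer reads the input, filter and output in full with
vector.transfer_read, unrolls the reduction dimension kw, extracts one
vector.extract_strided_slice of the input and one vector.extract of the
filter per kw position, accumulates them with vector.outerproduct (add kind),
and writes the result back with a single vector.transfer_write.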

Reviewed By: nicolasvasilache, hanchung, dcaballe, vmurali

Differential Revision: https://reviews.llvm.org/D145160

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3db61da84128a..275fab1f7a519 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -62,6 +62,112 @@ static OpType getSingleOpOfType(Block &block) {
   return res;
 }
 
+/// Helper function to extract the input slices after filter is unrolled along
+/// kw.
+static SmallVector<Value>
+extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
+                       int64_t nSize, int64_t wSize, int64_t cSize,
+                       int64_t kwSize, int strideW, int dilationW,
+                       int64_t wSizeStep, bool isSingleChanneled) {
+  SmallVector<Value> result;
+  if (isSingleChanneled) {
+    // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
+    // convolution.
+    SmallVector<int64_t> sizes{wSizeStep};
+    SmallVector<int64_t> strides{1};
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
+      }
+    }
+  } else {
+    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
+    // for channeled convolution.
+    SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
+    SmallVector<int64_t> strides{1, 1, 1};
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, input,
+            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
+            sizes, strides));
+      }
+    }
+  }
+  return result;
+}
+
+/// Helper function to extract the filter slices after filter is unrolled along
+/// kw.
+static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
+                                                  Location loc, Value filter,
+                                                  int64_t kwSize) {
+  SmallVector<Value> result;
+  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
+  // non-channeled convolution] @ [kw].
+  for (int64_t kw = 0; kw < kwSize; ++kw) {
+    result.push_back(rewriter.create<vector::ExtractOp>(
+        loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
+  }
+  return result;
+}
+
+/// Helper function to extract the result slices after filter is unrolled along
+/// kw.
+static SmallVector<Value>
+extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
+                        int64_t nSize, int64_t wSize, int64_t fSize,
+                        int64_t wSizeStep, bool isSingleChanneled) {
+  SmallVector<Value> result;
+  if (isSingleChanneled) {
+    // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
+    SmallVector<int64_t> sizes{wSizeStep};
+    SmallVector<int64_t> strides{1};
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
+    }
+  } else {
+    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
+    // convolution.
+    SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
+    SmallVector<int64_t> strides{1, 1, 1};
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
+    }
+  }
+  return result;
+}
+
+/// Helper function to insert the computed result slices.
+static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
+                                    Value res, int64_t wSize, int64_t wSizeStep,
+                                    SmallVectorImpl<Value> &resVals,
+                                    bool isSingleChanneled) {
+
+  if (isSingleChanneled) {
+    // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
+    // This does not depend on kw.
+    SmallVector<int64_t> strides{1};
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
+    }
+  } else {
+    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
+    // convolution. This does not depend on kw.
+    SmallVector<int64_t> strides{1, 1, 1};
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+          strides);
+    }
+  }
+  return res;
+}
+
 /// Contains the vectorization state and related methods used across the
 /// vectorization process of a given operation.
 struct VectorizationState {
@@ -334,6 +440,7 @@ static AffineMap reindexIndexingMap(AffineMap map) {
 
 /// Helper enum to represent conv1d input traversal order.
 enum class Conv1DOpOrder {
+  W,   // Corresponds to non-channeled 1D convolution operation.
   Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
   Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
 };
@@ -2055,6 +2162,15 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
 
 /// Generate a vector implementation for either:
 /// ```
+///   Op def: (     w,     kw  )
+///    Iters: ({Par(), Red()})
+///   Layout: {{w + kw}, {kw}, {w}}
+/// ```
+/// kw is unrolled.
+///
+/// or
+///
+/// ```
 ///   Op def: (     n,     w,     c,    kw,    f  )
 ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
@@ -2095,8 +2211,10 @@ struct Conv1DGenerator
     resShapedType = resShaped.getType().dyn_cast<ShapedType>();
     if (!lhsShapedType || !rhsShapedType || !resShapedType)
       return;
-    // LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC.
-    if (lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3)
+    // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
+    // (non-channeled convolution -> LHS and RES both have rank 1).
+    if (!((lhsShapedType.getRank() == 3 && resShapedType.getRank() == 3) ||
+          (lhsShapedType.getRank() == 1 && resShapedType.getRank() == 1)))
       return;
 
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
@@ -2115,7 +2233,7 @@ struct Conv1DGenerator
     auto rhsRank = rhsShapedType.getRank();
     switch (oper) {
     case Conv:
-      if (rhsRank != 2 && rhsRank!= 3)
+      if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
         return;
       break;
     case Pool:
@@ -2129,6 +2247,15 @@ struct Conv1DGenerator
 
   /// Generate a vector implementation for:
   /// ```
+  ///   Op def: (     w,     kw  )
+  ///    Iters: ({Par(), Red()})
+  ///   Layout: {{w + kw}, {kw}, {w}}
+  /// ```
+  /// kw is always unrolled.
+  ///
+  /// or
+  ///
+  /// ```
   ///   Op def: (     n,     w,     c,    kw,    f  )
   ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
@@ -2142,7 +2269,21 @@ struct Conv1DGenerator
 
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
+    bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
     switch (conv1DOpOrder) {
+    case Conv1DOpOrder::W:
+      // Initialize unused dimensions
+      nSize = fSize = cSize = 0;
+      // out{W}
+      bindShapeDims(resShapedType, wSize);
+      // kernel{kw}
+      bindShapeDims(rhsShapedType, kwSize);
+      lhsShape = {// iw = ow + kw - 1
+                  //   (i.e. 16 convolved with 3 -> 14)
+                  (wSize + kwSize - 1)};
+      rhsShape = {kwSize};
+      resShape = {wSize};
+      break;
     case Conv1DOpOrder::Nwc:
       // out{n, w, f}
       bindShapeDims(resShapedType, nSize, wSize, fSize);
@@ -2220,24 +2361,27 @@ struct Conv1DGenerator
     auto lhsType = VectorType::get(lhsShape, lhsEltType);
     auto rhsType = VectorType::get(rhsShape, rhsEltType);
     auto resType = VectorType::get(resShape, resEltType);
-    // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
-    // 0].
-    Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
-    // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
+    // Zero padding with the corresponding dimensions for lhs, rhs and res.
+    SmallVector<Value> lhsPadding(lhsShape.size(), zero);
+    SmallVector<Value> rhsPadding(rhsShape.size(), zero);
+    SmallVector<Value> resPadding(resShape.size(), zero);
+
+    // Read the whole lhs, rhs and res in one shot (with zero padding).
+    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
+                                                        lhsPadding);
     // This is needed only for Conv.
     Value rhs = nullptr;
     if (oper == Conv)
-      rhs = rewriter.create<vector::TransferReadOp>(
-          loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
-    // Read res slice of size {n, w, f} @ [0, 0, 0].
-    Value res = rewriter.create<vector::TransferReadOp>(
-        loc, resType, resShaped, ValueRange{zero, zero, zero});
-
-    // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output:
-    // {n,w,f}. To reuse the base pattern vectorization case, we do pre
-    // transpose on input, weight, and output.
+      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                    rhsPadding);
+    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
+                                                        resPadding);
+
+    // The base vectorization case for channeled convolution is input: {n,w,c},
+    // weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
+    // vectorization case, we do pre transpose on input, weight, and output.
     switch (conv1DOpOrder) {
+    case Conv1DOpOrder::W:
     case Conv1DOpOrder::Nwc:
       // Base case, so no transposes necessary.
       break;
@@ -2264,45 +2408,35 @@ struct Conv1DGenerator
     //===------------------------------------------------------------------===//
     // Unroll along kw and read slices of lhs and rhs.
     SmallVector<Value> lhsVals, rhsVals, resVals;
-    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
-    for (int64_t kw = 0; kw < kwSize; ++kw) {
-      for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, lhs,
-            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
-            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
-      }
-    }
-    // Extract rhs slice of size {c, f} @ [kw].
+    lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
+                                     kwSize, strideW, dilationW, wSizeStep,
+                                     isSingleChanneled);
     // Do not do for pooling.
     if (oper == Conv)
-      for (int64_t kw = 0; kw < kwSize; ++kw) {
-        rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-            loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
-      }
-    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
-    for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
-          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
-    }
+      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
+    resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
+                                      wSizeStep, isSingleChanneled);
 
     auto linearIndex = [&](int64_t kw, int64_t w) {
       return kw * (wSize / wSizeStep) + w;
     };
 
     // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
+    // perform outerproduct for non-channeled convolution or
     // perform simple arith operation for pooling
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         switch (oper) {
         case Conv:
-          resVals[w] = conv1dSliceAsContraction(rewriter, loc,
-                                                lhsVals[linearIndex(kw, w)],
-                                                rhsVals[kw], resVals[w]);
+          if (isSingleChanneled) {
+            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
+                                                   lhsVals[linearIndex(kw, w)],
+                                                   rhsVals[kw], resVals[w]);
+          } else {
+            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
+                                                  lhsVals[linearIndex(kw, w)],
+                                                  rhsVals[kw], resVals[w]);
+          }
           break;
         case Pool:
           resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
@@ -2312,22 +2446,17 @@ struct Conv1DGenerator
       }
     }
 
-    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
-    // This does not depend on kw.
-    for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, resVals[w], res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
-    }
+    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
+                                 isSingleChanneled);
     //===------------------------------------------------------------------===//
     // End vector-only rewrite part
     //===------------------------------------------------------------------===//
 
-    // The base vectorization case is output: {n,w,f}
+    // The base vectorization case for channeled convolution is output: {n,w,f}
     // To reuse the result from base pattern vectorization case, we post
     // transpose the base case result.
     switch (conv1DOpOrder) {
+    case Conv1DOpOrder::W:
     case Conv1DOpOrder::Nwc:
       // Base case, so no transposes necessary.
       break;
@@ -2339,10 +2468,8 @@ struct Conv1DGenerator
     }
     }
 
-    // Write back res slice of size {n, w, f} @ [0, 0, 0].
     return rewriter
-        .create<vector::TransferWriteOp>(loc, res, resShaped,
-                                         ValueRange{zero, zero, zero})
+        .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
         .getOperation();
   }
 
@@ -2359,6 +2486,14 @@ struct Conv1DGenerator
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
   }
 
+  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
+  // convolution.
+  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
+                                  Value lhs, Value rhs, Value res) {
+    return rewriter.create<vector::OuterProductOp>(
+        loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
+  }
+
   // Create a reduction: lhs{n, w, c} -> res{n, w, c}
   Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
                     Value res) {
@@ -2531,6 +2666,24 @@ struct Conv1DGenerator
     return rewriter.create<arith::AddIOp>(loc, mul, res);
   }
 
+  /// Entry point for non-channeled convolution:
+  ///   {{w + kw}, {kw}, {w}}
+  FailureOr<Operation *> generateNonChanneledConv() {
+    AffineExpr w, kw;
+    bindDims(ctx, w, kw);
+    if (!iters({Par(), Red()}))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to match conv::W 1-par 1-red");
+
+    // No transposition needed.
+    if (layout({/*lhsIndex*/ {w + kw},
+                /*rhsIndex*/ {kw},
+                /*resIndex*/ {w}}))
+      return conv(Conv1DOpOrder::W);
+
+    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
+  }
+
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
   FailureOr<Operation *> generateNwcConv() {
@@ -2691,7 +2844,10 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
   Conv1DGenerator e(rewriter, op, stride, dilation);
-  auto res = e.generateNwcConv();
+  auto res = e.generateNonChanneledConv();
+  if (succeeded(res))
+    return res;
+  res = e.generateNwcConv();
   if (succeeded(res))
     return res;
   res = e.generateNcwConv();

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 91d822d804c05..db235fad7d2ee 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -461,6 +461,59 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 //      CHECK:   vector.transfer_write %[[RES]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
 
+// -----
+
+func.func @conv1d_8_tensor(%input: tensor<11xf32>, %filter: tensor<4xf32>, %output: tensor<8xf32>) -> tensor<8xf32> {
+  %0 = linalg.conv_1d ins(%input, %filter : tensor<11xf32>, tensor<4xf32>)
+                     outs(%output : tensor<8xf32>) -> tensor<8xf32>
+  return %0 : tensor<8xf32>
+}
+
+//      CHECK: func @conv1d_8_tensor
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<11xf32>, %[[FILTER:.+]]: tensor<4xf32>, %[[OUTPUT:.+]]: tensor<8xf32>)
+
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Read the whole data in one shot.
+//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]]], %[[F0]]
+//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]]], %[[F0]]
+//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]]], %[[F0]]
+
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [1], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
+//      CHECK:   %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [2], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
+//      CHECK:   %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [3], sizes = [8], strides = [1]} : vector<11xf32> to vector<8xf32>
+
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xf32>
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xf32>
+//      CHECK:  %[[V_FILTER_2:.+]] = vector.extract %[[V_FILTER_R]][2] : vector<4xf32>
+//      CHECK:  %[[V_FILTER_3:.+]] = vector.extract %[[V_FILTER_R]][3] : vector<4xf32>
+
+/// w == 0, kw == 0
+//      CHECK:   %[[RES_0:.+]] = vector.outerproduct
+// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]] {kind = #vector.kind<add>}
+// CHECK-SAME:     : vector<8xf32>, f32
+/// w == 0, kw == 1
+//      CHECK:   %[[RES_1:.+]] = vector.outerproduct
+// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_1]], %[[RES_0]] {kind = #vector.kind<add>}
+// CHECK-SAME:     : vector<8xf32>, f32
+/// w == 0, kw == 2
+//      CHECK:   %[[RES_2:.+]] = vector.outerproduct
+// CHECK-SAME:     %[[V_INPUT_2]], %[[V_FILTER_2]], %[[RES_1]] {kind = #vector.kind<add>}
+// CHECK-SAME:     : vector<8xf32>, f32
+/// w == 0, kw == 3
+//      CHECK:   %[[RES_3:.+]] = vector.outerproduct
+// CHECK-SAME:     %[[V_INPUT_3]], %[[V_FILTER_3]], %[[RES_2]] {kind = #vector.kind<add>}
+// CHECK-SAME:     : vector<8xf32>, f32
+
+// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[RES_3]], %[[OUTPUT]][%[[C0]]]
+
 // -----
 
 func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {


        

