[Mlir-commits] [mlir] 755e776 - [mlir][linalg] Vectorize 1D convolution

Murali Vijayaraghavan llvmlistbot at llvm.org
Thu Jan 5 15:16:19 PST 2023


Author: Murali Vijayaraghavan
Date: 2023-01-05T23:08:32Z
New Revision: 755e776849be84e1a73dcf88c62a95346d2574f2

URL: https://github.com/llvm/llvm-project/commit/755e776849be84e1a73dcf88c62a95346d2574f2
DIFF: https://github.com/llvm/llvm-project/commit/755e776849be84e1a73dcf88c62a95346d2574f2.diff

LOG: [mlir][linalg] Vectorize 1D convolution

Differential Revision: https://reviews.llvm.org/D140188
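
For example, the tests added below exercise 1-D sum and max pooling ops such as:

    linalg.pooling_nwc_sum
      {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
      ins(%input, %filter : memref<4x4x3xf32>, memref<1xf32>)
      outs(%output : memref<4x2x3xf32>)

These are vectorized into vector transfer reads, per-window-position
extract_strided_slice / arith / insert_strided_slice ops, and a final
transfer write.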

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a3c5ecd2ba4fa..c0cede453470d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -385,8 +385,10 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
           [&](auto op) { return CombiningKind::ADD; })
       .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
       .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
+      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
       .Case<arith::MaxFOp>([&](auto op) { return CombiningKind::MAXF; })
       .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
+      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
       .Case<arith::MinFOp>([&](auto op) { return CombiningKind::MINF; })
       .Case<arith::MulIOp, arith::MulFOp>(
           [&](auto op) { return CombiningKind::MUL; })
@@ -1796,6 +1798,26 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
 }
 
 namespace {
+bool isCastOfBlockArgument(Operation *op) {
+  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
+         op->getOperand(0).isa<BlockArgument>();
+}
+
+bool isSupportedPoolKind(vector::CombiningKind kind) {
+  switch (kind) {
+  case vector::CombiningKind::ADD:
+  case vector::CombiningKind::MAXF:
+  case vector::CombiningKind::MAXSI:
+  case vector::CombiningKind::MAXUI:
+  case vector::CombiningKind::MINF:
+  case vector::CombiningKind::MINSI:
+  case vector::CombiningKind::MINUI:
+    return true;
+  default:
+    return false;
+  }
+}
+
 /// Generate a vector implementation for either:
 /// ```
 ///   Op def: (     n,     w,     c,    kw,    f  )
@@ -1838,41 +1860,33 @@ struct Conv1DGenerator
     resShapedType = resShaped.getType().dyn_cast<ShapedType>();
     if (!lhsShapedType || !rhsShapedType || !resShapedType)
       return;
-    if (lhsShapedType.getRank() != 3 ||
-        (rhsShapedType.getRank() != 2 && rhsShapedType.getRank() != 3) ||
-        resShapedType.getRank() != 3)
+    // LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC.
+    if (lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3)
       return;
 
-    // Check for reduction `add` preceded by `mul`.
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
     if (!reduceOp)
       return;
-    std::optional<vector::CombiningKind> maybeKind;
-    maybeKind = getCombinerOpKind(reduceOp);
-    if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
+    redOp = reduceOp->getName().getIdentifier();
+
+    if (!setOperKind(reduceOp))
       return;
-    // Check for single `mul` predecessor. The `mul` operands must be block
-    // arguments or extension of block arguments.
-    Operation *mulOp = nullptr;
-    for (Value operand : reduceOp->getOperands()) {
-      if (operand.isa<BlockArgument>())
-        continue;
-      if (mulOp)
-        return;
-      mulOp = operand.getDefiningOp();
-      if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
-        return;
-    }
-    if (!mulOp)
+    auto maybeKind = getCombinerOpKind(reduceOp);
+    if (!(maybeKind && (*maybeKind == vector::CombiningKind::ADD ||
+                        (oper == Pool && isSupportedPoolKind(*maybeKind))))) {
       return;
-    for (Value operand : mulOp->getOperands()) {
-      if (Operation *def = operand.getDefiningOp()) {
-        if (!isa<CastOpInterface>(def))
-          return;
-        operand = def->getOperand(0);
-      }
-      if (!operand.isa<BlockArgument>())
+    }
+
+    auto rhsRank = rhsShapedType.getRank();
+    switch (oper) {
+    case Conv:
+      if (rhsRank != 2 && rhsRank != 3)
+        return;
+      break;
+    case Pool:
+      if (rhsRank != 1)
         return;
+      break;
     }
     // The op is now known to be valid.
     valid = true;
@@ -1889,16 +1903,25 @@ struct Conv1DGenerator
   /// > 1.
   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
     if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv");
+      return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
 
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
     switch (conv1DOpOrder) {
     case Conv1DOpOrder::Nwc:
-      // kernel{kw, c, f}
-      bindShapeDims(rhsShapedType, kwSize, cSize, fSize);
       // out{n, w, f}
-      bindShapeDims(resShapedType, nSize, wSize);
+      bindShapeDims(resShapedType, nSize, wSize, fSize);
+      switch (oper) {
+      case Conv:
+        // kernel{kw, c, f}
+        bindShapeDims(rhsShapedType, kwSize, cSize);
+        break;
+      case Pool:
+        // kernel{kw}
+        bindShapeDims(rhsShapedType, kwSize);
+        cSize = fSize;
+        break;
+      }
       lhsShape = {nSize,
                   // iw = ow * sw + kw *  dw - 1
                   //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
@@ -1906,21 +1929,44 @@ struct Conv1DGenerator
                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                       1,
                   cSize};
-      rhsShape = {kwSize, cSize, fSize};
+      switch (oper) {
+      case Conv:
+        rhsShape = {kwSize, cSize, fSize};
+        break;
+      case Pool:
+        rhsShape = {kwSize};
+        break;
+      }
       resShape = {nSize, wSize, fSize};
       break;
     case Conv1DOpOrder::Ncw:
-      // kernel{f, c, kw}
-      bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
       // out{n, f, w}
       bindShapeDims(resShapedType, nSize, fSize, wSize);
+      switch (oper) {
+      case Conv:
+        // kernel{f, c, kw}
+        bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
+        break;
+      case Pool:
+        // kernel{kw}
+        bindShapeDims(rhsShapedType, kwSize);
+        cSize = fSize;
+        break;
+      }
       lhsShape = {nSize, cSize,
                   // iw = ow * sw + kw *  dw - 1
                   //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
                   // Perform the proper inclusive -> exclusive -> inclusive.
                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                       1};
-      rhsShape = {fSize, cSize, kwSize};
+      switch (oper) {
+      case Conv:
+        rhsShape = {fSize, cSize, kwSize};
+        break;
+      case Pool:
+        rhsShape = {kwSize};
+        break;
+      }
       resShape = {nSize, fSize, wSize};
       break;
     }
@@ -1944,8 +1990,11 @@ struct Conv1DGenerator
     Value lhs = rewriter.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
-    Value rhs = rewriter.create<vector::TransferReadOp>(
-        loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
+    // This is needed only for Conv.
+    Value rhs = nullptr;
+    if (oper == Conv)
+      rhs = rewriter.create<vector::TransferReadOp>(
+          loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
     // Read res slice of size {n, w, f} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
@@ -1964,7 +2013,10 @@ struct Conv1DGenerator
       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
       // fcw -> wcf
       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
-      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
+
+      // This is needed only for Conv.
+      if (oper == Conv)
+        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
       // nfw -> nwf
       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
       res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
@@ -1988,10 +2040,12 @@ struct Conv1DGenerator
       }
     }
     // Extract rhs slice of size {c, f} @ [kw].
-    for (int64_t kw = 0; kw < kwSize; ++kw) {
-      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
-    }
+    // Not needed for pooling.
+    if (oper == Conv)
+      for (int64_t kw = 0; kw < kwSize; ++kw) {
+        rhsVals.push_back(rewriter.create<vector::ExtractOp>(
+            loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+      }
     // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
@@ -2005,11 +2059,21 @@ struct Conv1DGenerator
       return kw * (wSize / wSizeStep) + w;
     };
 
-    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
+    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
+    // perform a simple arith operation for pooling.
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = conv1dSliceAsContraction(
-            rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
+        switch (oper) {
+        case Conv:
+          resVals[w] = conv1dSliceAsContraction(rewriter, loc,
+                                                lhsVals[linearIndex(kw, w)],
+                                                rhsVals[kw], resVals[w]);
+          break;
+        case Pool:
+          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
+                                   resVals[w]);
+          break;
+        }
       }
     }
 
@@ -2060,6 +2124,16 @@ struct Conv1DGenerator
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
   }
 
+  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
+  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
+                    Value res) {
+    if (isPoolExt)
+      lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
+    return rewriter
+        .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
+        ->getResult(0);
+  }
+
   /// Generate a vector implementation for:
   /// ```
   ///   Op def: (     n,     w,     c,    kw)
@@ -2236,6 +2310,7 @@ struct Conv1DGenerator
                 /*rhsIndex*/ {kw, c, f},
                 /*resIndex*/ {n, w, f}}))
       return conv(Conv1DOpOrder::Nwc);
+
     return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
   }
 
@@ -2256,6 +2331,41 @@ struct Conv1DGenerator
     return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
   }
 
+  /// Entry point that transposes into the common form:
+  ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
+  FailureOr<Operation *> generateNwcPooling() {
+    AffineExpr n, w, c, kw;
+    bindDims(ctx, n, w, c, kw);
+    if (!iters({Par(), Par(), Par(), Red()}))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to match pooling 3-par 1-red");
+
+    // No transposition needed.
+    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
+                /*rhsIndex*/ {kw},
+                /*resIndex*/ {n, w, c}}))
+      return conv(Conv1DOpOrder::Nwc);
+
+    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
+  }
+
+  /// Entry point that transposes into the common form:
+  ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
+  FailureOr<Operation *> generateNcwPooling() {
+    AffineExpr n, w, c, kw;
+    bindDims(ctx, n, c, w, kw);
+    if (!iters({Par(), Par(), Par(), Red()}))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to match pooling 3-par 1-red");
+
+    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
+                /*rhsIndex*/ {kw},
+                /*resIndex*/ {n, c, w}}))
+      return conv(Conv1DOpOrder::Ncw);
+
+    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
+  }
+
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
   FailureOr<Operation *> generateDilatedConv() {
@@ -2275,10 +2385,61 @@ struct Conv1DGenerator
   }
 
 private:
+  enum OperKind { Conv, Pool };
   bool valid = false;
+  OperKind oper = Conv;
+  StringAttr redOp;
+  StringAttr poolExtOp;
+  bool isPoolExt = false;
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
+
+  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
+  // Returns true iff it is a valid conv/pooling op.
+  // If the region has 2 ops (reduction + yield) or 3 ops (extension +
+  // reduction + yield) and rhs is unused, then it is the body of a pooling.
+  // If it is a conv, check for a single `mul` predecessor. The `mul` operands
+  // must be block arguments or extensions of block arguments.
+  // Otherwise, check for zero or one `ext` predecessor. The `ext` operand
+  // must be a block argument.
+  bool setOperKind(Operation *reduceOp) {
+    int numBlockArguments =
+        llvm::count_if(reduceOp->getOperands(),
+                       [](Value v) { return v.isa<BlockArgument>(); });
+    switch (numBlockArguments) {
+    case 1: {
+      // Will be a convolution if the feeder is a MulOp.
+      // Otherwise, it can be pooling.
+      auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
+        return !v.isa<BlockArgument>();
+      });
+      Operation *feedOp = (*feedValIt).getDefiningOp();
+      if (isCastOfBlockArgument(feedOp)) {
+        oper = Pool;
+        isPoolExt = true;
+        poolExtOp = feedOp->getName().getIdentifier();
+      } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
+                   llvm::all_of(feedOp->getOperands(), [](Value v) {
+                     if (v.isa<BlockArgument>())
+                       return true;
+                     if (Operation *op = v.getDefiningOp())
+                       return isCastOfBlockArgument(op);
+                     return false;
+                   }))) {
+        return false;
+      }
+      return true;
+    }
+    case 2:
+      // Must be pooling
+      oper = Pool;
+      isPoolExt = false;
+      return true;
+    default:
+      return false;
+    }
+  }
 };
 } // namespace
 
@@ -2299,6 +2460,12 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
   if (succeeded(res))
     return res;
   res = e.generateNcwConv();
+  if (succeeded(res))
+    return res;
+  res = e.generateNwcPooling();
+  if (succeeded(res))
+    return res;
+  res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
   return e.generateDilatedConv();

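As a concrete instance of pool1dSlice above: each lhs window slice is combined
into the accumulator with the matched reduction op, preceded by the matched
extension op when the pooling body contains one. In the i8-to-i32 sum-pooling
test below, one slice update is emitted as:

    %6 = arith.extsi %2 : vector<4x1x3xi8> to vector<4x1x3xi32>
    %7 = arith.addi %6, %4 : vector<4x1x3xi32>
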
diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 207ef8e7522cf..6c2697aee6881 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -571,3 +571,282 @@ func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter:
 //      CHECK:   %[[CONT:.*]] = vector.contract
 //        {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
 //      CHECK:   vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+// -----
+
+func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
+  linalg.pooling_nwc_sum
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x4x3xf32>, memref<1xf32>)
+    outs(%output : memref<4x2x3xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @pooling_nwc_sum_memref_1_2_1_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xf32>, %[[Varg1:.+]]: memref<1xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x4x3xf32>, vector<4x4x3xf32>
+// CHECK: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
+// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V6:.+]] = arith.addf %[[V2]], %[[V4]] : vector<4x1x3xf32>
+// CHECK: %[[V7:.+]] = arith.addf %[[V3]], %[[V5]] : vector<4x1x3xf32>
+// CHECK: %[[V8:.+]] = vector.insert_strided_slice %[[V6]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK: %[[V9:.+]] = vector.insert_strided_slice %[[V7]], %[[V8]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK: vector.transfer_write %[[V9]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
+
+// -----
+
+func.func @pooling_nwc_max_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
+  linalg.pooling_nwc_max
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x4x3xf32>, memref<1xf32>)
+    outs(%output : memref<4x2x3xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @pooling_nwc_max_memref_1_2_1_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xf32>, %[[Varg1:.+]]: memref<1xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x4x3xf32>, vector<4x4x3xf32>
+// CHECK: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
+// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V6:.+]] = arith.maxf %[[V2]], %[[V4]] : vector<4x1x3xf32>
+// CHECK: %[[V7:.+]] = arith.maxf %[[V3]], %[[V5]] : vector<4x1x3xf32>
+// CHECK: %[[V8:.+]] = vector.insert_strided_slice %[[V6]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK: %[[V9:.+]] = vector.insert_strided_slice %[[V7]], %[[V8]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK: vector.transfer_write %[[V9]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
+
+// -----
+
+// The i8i8i32 case is similar to the f32 case, so checking one case is
+// enough for test coverage.
+func.func @pooling_nwc_sum_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %filter: memref<1xi8>, %output: memref<4x2x3xi32>) {
+  linalg.pooling_nwc_sum
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x4x3xi8>, memref<1xi8>)
+    outs(%output : memref<4x2x3xi32>)
+  return
+}
+
+// CHECK-LABEL: func.func @pooling_nwc_sum_i8i8i32_memref_1_2_1_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xi8>, %[[Varg1:.+]]: memref<1xi8>, %[[Varg2:.+]]: memref<4x2x3xi32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vc0_i8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[Vc0_i32:.+]] = arith.constant 0 : i32
+// CHECK: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i8]] {in_bounds = [true, true, true]} : memref<4x4x3xi8>, vector<4x4x3xi8>
+// CHECK: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i32]] {in_bounds = [true, true, true]} : memref<4x2x3xi32>, vector<4x2x3xi32>
+// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
+// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
+// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32>
+// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32>
+// CHECK: %[[V6:.+]] = arith.extsi %[[V2]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+// CHECK: %[[V7:.+]] = arith.addi %[[V6]], %[[V4]] : vector<4x1x3xi32>
+// CHECK: %[[V8:.+]] = arith.extsi %[[V3]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+// CHECK: %[[V9:.+]] = arith.addi %[[V8]], %[[V5]] : vector<4x1x3xi32>
+// CHECK: %[[V10:.+]] = vector.insert_strided_slice %[[V7]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32>
+// CHECK: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32>
+// CHECK: vector.transfer_write %[[V11]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xi32>, memref<4x2x3xi32>
+// CHECK: return
+
+// -----
+
+// The i8i8i32 case is similar to the f32 case, so checking one case is
+// enough for test coverage.
+func.func @pooling_nwc_max_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %filter: memref<1xi8>, %output: memref<4x2x3xi32>) {
+  linalg.pooling_nwc_max
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x4x3xi8>, memref<1xi8>)
+    outs(%output : memref<4x2x3xi32>)
+  return
+}
+
+// CHECK-LABEL: func.func @pooling_nwc_max_i8i8i32_memref_1_2_1_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xi8>, %[[Varg1:.+]]: memref<1xi8>, %[[Varg2:.+]]: memref<4x2x3xi32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vc0_i8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[Vc0_i32:.+]] = arith.constant 0 : i32
+// CHECK: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i8]] {in_bounds = [true, true, true]} : memref<4x4x3xi8>, vector<4x4x3xi8>
+// CHECK: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vc0_i32]] {in_bounds = [true, true, true]} : memref<4x2x3xi32>, vector<4x2x3xi32>
+// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
+// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xi8> to vector<4x1x3xi8>
+// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32>
+// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xi32> to vector<4x1x3xi32>
+// CHECK: %[[V6:.+]] = arith.extsi %[[V2]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+// CHECK: %[[V7:.+]] = arith.maxsi %[[V6]], %[[V4]] : vector<4x1x3xi32>
+// CHECK: %[[V8:.+]] = arith.extsi %[[V3]] : vector<4x1x3xi8> to vector<4x1x3xi32>
+// CHECK: %[[V9:.+]] = arith.maxsi %[[V8]], %[[V5]] : vector<4x1x3xi32>
+// CHECK: %[[V10:.+]] = vector.insert_strided_slice %[[V7]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32>
+// CHECK: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xi32> into vector<4x2x3xi32>
+// CHECK: vector.transfer_write %[[V11]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xi32>, memref<4x2x3xi32>
+// CHECK: return
+
+// -----
+
+func.func @pooling_nwc_sum_memref_2_2_2_3(%input: memref<4x6x3xf32>, %filter: memref<2xf32>, %output: memref<4x2x3xf32>) {
+  linalg.pooling_nwc_sum
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xf32>, memref<2xf32>)
+    outs(%output : memref<4x2x3xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @pooling_nwc_sum_memref_2_2_2_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x6x3xf32>, %[[Varg1:.+]]: memref<2xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x6x3xf32>, vector<4x6x3xf32>
+// CHECK: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
+// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V6:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V7:.+]] = vector.extract_strided_slice %[[V1]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V8:.+]] = arith.addf %[[V2]], %[[V6]] : vector<4x1x3xf32>
+// CHECK: %[[V9:.+]] = arith.addf %[[V3]], %[[V7]] : vector<4x1x3xf32>
+// CHECK: %[[V10:.+]] = arith.addf %[[V4]], %[[V8]] : vector<4x1x3xf32>
+// CHECK: %[[V11:.+]] = arith.addf %[[V5]], %[[V9]] : vector<4x1x3xf32>
+// CHECK: %[[V12:.+]] = vector.insert_strided_slice %[[V10]], %[[V1]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK: %[[V13:.+]] = vector.insert_strided_slice %[[V11]], %[[V12]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK: vector.transfer_write %[[V13:.+]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
+
+
+// -----
+
+func.func @pooling_ncw_sum_memref_1_2_1_3(%input: memref<4x3x4xf32>, %filter: memref<1xf32>, %output: memref<4x3x2xf32>) {
+  linalg.pooling_ncw_sum
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x3x4xf32>, memref<1xf32>)
+    outs(%output : memref<4x3x2xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @pooling_ncw_sum_memref_1_2_1_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x3x4xf32>, %[[Varg1:.+]]: memref<1xf32>, %[[Varg2:.+]]: memref<4x3x2xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x4xf32>, vector<4x3x4xf32>
+// CHECK: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x2xf32>, vector<4x3x2xf32>
+// CHECK: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x3x4xf32> to vector<4x4x3xf32>
+// CHECK: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
+// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V6:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V7:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V8:.+]] = arith.addf %[[V4]], %[[V6]] : vector<4x1x3xf32>
+// CHECK: %[[V9:.+]] = arith.addf %[[V5]], %[[V7]] : vector<4x1x3xf32>
+// CHECK: %[[V10:.+]] = vector.insert_strided_slice %[[V8]], %[[V3]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK: %[[V11:.+]] = vector.insert_strided_slice %[[V9]], %[[V10]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK: %[[V12:.+]] = vector.transpose %[[V11]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
+// CHECK: vector.transfer_write %[[V12:.+]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x3x2xf32>, memref<4x3x2xf32>
+
+
+// -----
+
+func.func @pooling_nwc_sum_mixed_type_memref_1_2_1_1(%input: memref<1x2x3xf16>, %filter: memref<1xf16>, %output: memref<1x2x3xf32>) {
+  linalg.pooling_nwc_sum
+  {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
+   ins(%input, %filter : memref<1x2x3xf16>, memref<1xf16>)
+   outs(%output : memref<1x2x3xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @pooling_nwc_sum_mixed_type_memref_1_2_1_1
+// CHECK-SAME: (%[[Varg0:.+]]: memref<1x2x3xf16>, %[[Varg1:.+]]: memref<1xf16>, %[[Varg2:.+]]: memref<1x2x3xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[Vcst_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<1x2x3xf16>, vector<1x2x3xf16>
+// CHECK: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst_0]] {in_bounds = [true, true, true]} : memref<1x2x3xf32>, vector<1x2x3xf32>
+// CHECK: %[[V2:.+]] = arith.extf %[[V0]] : vector<1x2x3xf16> to vector<1x2x3xf32>
+// CHECK: %[[V3:.+]] = arith.addf %[[V2]], %[[V1]] : vector<1x2x3xf32>
+// CHECK: vector.transfer_write %[[V3:.+]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<1x2x3xf32>, memref<1x2x3xf32>
+
+// -----
+
+func.func @pooling_nwc_sum_memref_2_2_2_1(%input: memref<4x4x3xf32>, %filter: memref<2xf32>, %output: memref<4x2x3xf32>) {
+  linalg.pooling_nwc_sum
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x4x3xf32>, memref<2xf32>)
+    outs(%output : memref<4x2x3xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @pooling_nwc_sum_memref_2_2_2_1
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x4x3xf32>, %[[Varg1:.+]]: memref<2xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x4x3xf32>, vector<4x4x3xf32>
+// CHECK: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
+// CHECK: %[[V2:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
+// CHECK: %[[V3:.+]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
+// CHECK: %[[V4:.+]] = arith.addf %[[V2]], %[[V1]] : vector<4x2x3xf32>
+// CHECK: %[[V5:.+]] = arith.addf %[[V3]], %[[V4]] : vector<4x2x3xf32>
+// CHECK: vector.transfer_write %[[V5:.+]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
+
+
+// -----
+
+func.func @pooling_ncw_sum_memref_2_2_2_3(%input: memref<4x3x6xf32>, %filter: memref<2xf32>, %output: memref<4x3x2xf32>) {
+  linalg.pooling_ncw_sum
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x3x6xf32>, memref<2xf32>)
+    outs(%output : memref<4x3x2xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @pooling_ncw_sum_memref_2_2_2_3
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x3x6xf32>, %[[Varg1:.+]]: memref<2xf32>, %[[Varg2:.+]]: memref<4x3x2xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x6xf32>, vector<4x3x6xf32>
+// CHECK: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x3x2xf32>, vector<4x3x2xf32>
+// CHECK: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x3x6xf32> to vector<4x6x3xf32>
+// CHECK: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
+// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V6:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V7:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V8:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V9:.+]] = vector.extract_strided_slice %[[V3]] {offsets = [0, 1, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x3xf32> to vector<4x1x3xf32>
+// CHECK: %[[V10:.+]] = arith.addf %[[V4]], %[[V8]] : vector<4x1x3xf32>
+// CHECK: %[[V11:.+]] = arith.addf %[[V5]], %[[V9]] : vector<4x1x3xf32>
+// CHECK: %[[V12:.+]] = arith.addf %[[V6]], %[[V10]] : vector<4x1x3xf32>
+// CHECK: %[[V13:.+]] = arith.addf %[[V7]], %[[V11]] : vector<4x1x3xf32>
+// CHECK: %[[V14:.+]] = vector.insert_strided_slice %[[V12]], %[[V3]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK: %[[V15:.+]] = vector.insert_strided_slice %[[V13]], %[[V14]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x3xf32> into vector<4x2x3xf32>
+// CHECK: %[[V16:.+]] = vector.transpose %[[V15]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
+// CHECK: vector.transfer_write %[[V16:.+]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x3x2xf32>, memref<4x3x2xf32>
+
+// -----
+
+func.func @pooling_ncw_sum_memref_2_3_2_1(%input: memref<4x2x5xf32>, %filter: memref<2xf32>, %output: memref<4x2x3xf32>) {
+  linalg.pooling_ncw_sum
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x2x5xf32>, memref<2xf32>)
+    outs(%output : memref<4x2x3xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @pooling_ncw_sum_memref_2_3_2_1
+// CHECK-SAME: (%[[Varg0:.+]]: memref<4x2x5xf32>, %[[Varg1:.+]]: memref<2xf32>, %[[Varg2:.+]]: memref<4x2x3xf32>)
+// CHECK-DAG: %[[Vc0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[Vcst:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[V0:.+]] = vector.transfer_read %[[Varg0]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x5xf32>, vector<4x2x5xf32>
+// CHECK: %[[V1:.+]] = vector.transfer_read %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]], %[[Vcst]] {in_bounds = [true, true, true]} : memref<4x2x3xf32>, vector<4x2x3xf32>
+// CHECK: %[[V2:.+]] = vector.transpose %[[V0]], [0, 2, 1] : vector<4x2x5xf32> to vector<4x5x2xf32>
+// CHECK: %[[V3:.+]] = vector.transpose %[[V1]], [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
+// CHECK: %[[V4:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 0, 0], sizes = [4, 3, 2], strides = [1, 1, 1]} : vector<4x5x2xf32> to vector<4x3x2xf32>
+// CHECK: %[[V5:.+]] = vector.extract_strided_slice %[[V2]] {offsets = [0, 2, 0], sizes = [4, 3, 2], strides = [1, 1, 1]} : vector<4x5x2xf32> to vector<4x3x2xf32>
+// CHECK: %[[V6:.+]] = arith.addf %[[V4]], %[[V3]] : vector<4x3x2xf32>
+// CHECK: %[[V7:.+]] = arith.addf %[[V5]], %[[V6]] : vector<4x3x2xf32>
+// CHECK: %[[V8:.+]] = vector.transpose %[[V7]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
+// CHECK: vector.transfer_write %[[V8:.+]], %[[Varg2]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>

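The pooling_ncw_* cases above follow the same channels-last strategy as Ncw
convolutions: the input and output reads are transposed into nwc form, the
pooling updates are performed there, and the result is transposed back before
the final write, e.g.:

    %2 = vector.transpose %0, [0, 2, 1] : vector<4x3x4xf32> to vector<4x4x3xf32>
    ...
    %12 = vector.transpose %11, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>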