[Mlir-commits] [mlir] [mlir][linalg][conv] Flatten the channel dimension when vectorizing (PR #71918)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Dec 6 10:41:55 PST 2023
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/71918
>From 2ddab31d43b462ebf6de6cf550fe66763b088227 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 8 Oct 2023 19:53:04 +0100
Subject: [PATCH 1/4] [mlir][linalg][conv] Flatten the channel dimension when
vectorizing
The current vectorization of 1D depthwise convolutions in Linalg is
_sub-optimal_ for tensors with a small channel dimension (i.e. few channels), e.g.:
```mlir
linalg.depthwise_conv_1d_nwc_wc
{dilations = dense<1> : vector<1xi64>,
strides = dense<1> : vector<1xi64>}
ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
```
That's due to the fact that ultimately (i.e. at the LLVM level),
vectorization happens along the trailing dimension (i.e. the channel
dimension). In this case it leads to vectors with 3 elements (or worse,
vectors with a single element if there is only 1 channel). For comparison,
a 128-bit wide vector register can hold 16 x i8.
Instead, this patch adds an option to flatten/collapse the channel
dimension into the width dimension of the input/output:
```mlir
%collapsed = tensor.collapse_shape %input [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
%collapsed_0 = tensor.collapse_shape %output [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
```
(Note that for this to work, the filter is broadcast rather than
re-shaped. Please see the test cases for details).
The new vectorization strategy is implemented in `depthwiseConvFlatten`,
which is based on `depthwiseConvGeneric` (i.e. the original
vectorization hook). The current vectorization is preserved and kept as
the default option. The new vectorization can be selected through e.g. a
transform dialect attribute:
```mlir
transform.structured.vectorize_children_and_apply_patterns %conv {flatten_1d_depthwise_conv}
```
A forthcoming patch will implement a strategy to automatically switch
between the two implementations, depending on the shape of the input
tensors.
Co-authored-by: Bradley Smith <bradley.smith at arm.com>
---
.../Linalg/TransformOps/LinalgTransformOps.td | 4 +-
.../Dialect/Linalg/Transforms/Transforms.h | 3 +-
.../TransformOps/LinalgTransformOps.cpp | 21 +-
.../Linalg/Transforms/Vectorization.cpp | 228 +++++++++++++++++-
.../Linalg/vectorize-convolution-flatten.mlir | 222 +++++++++++++++++
.../Dialect/Linalg/vectorize-convolution.mlir | 12 +-
6 files changed, 466 insertions(+), 24 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 002926ff965fd..de65f3176c46a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2038,6 +2038,7 @@ def VectorizeChildrenAndApplyPatternsOp :
let arguments = (ins TransformHandleTypeInterface:$target,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
+ UnitAttr:$flatten_1d_depthwise_conv,
UnitAttr:$disable_multi_reduction_to_contract_patterns,
UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
let results = (outs TransformHandleTypeInterface:$transformed);
@@ -2049,7 +2050,8 @@ def VectorizeChildrenAndApplyPatternsOp :
let builders = [
OpBuilder<(ins "Value":$target,
CArg<"bool", "false">:$vectorizePadding,
- CArg<"bool", "false">:$vectorizeNDExtract)>,
+ CArg<"bool", "false">:$vectorizeNDExtract,
+ CArg<"bool", "false">:$flatten1DDepthwise)>
];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6c4e16bd94f47..3f4dfe42b71fd 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -753,7 +753,8 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
- bool vectorizeNDExtract = false);
+ bool vectorizeNDExtract = false,
+ // When true, request that 1D depthwise convolutions be vectorized by
+ // collapsing the channel dimension into the width dimension of the
+ // input/output (where supported).
+ bool flatten1DDepthwiseConv = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 14404d837ff74..279a7c487b7d5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2943,7 +2943,7 @@ LogicalResult TileUsingForallOp::verify() {
void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
- bool vectorizePadding, bool vectorizeExtract) {
+ bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
result.addOperands(target);
if (vectorizePadding) {
result.addAttribute(
@@ -2957,6 +2957,12 @@ void transform::VectorizeChildrenAndApplyPatternsOp::build(
result.name),
builder.getUnitAttr());
}
+ if (flatten1DDepthwiseConv) {
+ result.addAttribute(
+ VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
+ result.name),
+ builder.getUnitAttr());
+ }
result.addTypes(transform::AnyOpType::get(builder.getContext()));
}
@@ -2965,22 +2971,26 @@ namespace {
/// VectorizeChildrenAndApplyPatternsOp::applyToOne.
struct VectorizationPattern : public RewritePattern {
explicit VectorizationPattern(MLIRContext *context,
- bool vectorizeExtract = false)
+ bool vectorizeExtract = false,
+ bool flattenConv = false)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
- vectorizeNDExtract(vectorizeExtract) {}
+ vectorizeNDExtract(vectorizeExtract),
+ flatten1DDepthwiseConv(flattenConv) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return rewriter.notifyMatchFailure(op, "expected Linalg Op");
return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
- /*scalableVecDims=*/{}, vectorizeNDExtract);
+ /*scalableVecDims=*/{}, vectorizeNDExtract,
+ flatten1DDepthwiseConv);
}
private:
/// Controls whether to vectorize `tensor.extract` when the input tensor is
/// rank >= 2.
bool vectorizeNDExtract = false;
+ /// Controls whether to flatten the channel dimension into the width
+ /// dimension when vectorizing 1D depthwise convolutions.
+ bool flatten1DDepthwiseConv = false;
};
} // namespace
@@ -2997,7 +3007,8 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
MLIRContext *ctx = getContext();
RewritePatternSet patterns(ctx);
- patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());
+ patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
+ getFlatten_1dDepthwiseConv());
if (!getDisableTransferPermutationMapLoweringPatterns())
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f9a53a8451a60..0a947af87144b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -44,8 +44,9 @@ using namespace mlir::linalg;
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
/// Try to vectorize `convOp` as a convolution.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
- LinalgOp convOp);
+/// `flatten1DDepthwiseConv` requests that 1D depthwise convolutions be
+/// vectorized by collapsing the channel dimension into the width dimension
+/// (where supported).
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+ bool flatten1DDepthwiseConv = false);
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
@@ -1664,7 +1665,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
- bool vectorizeNDExtract) {
+ bool vectorizeNDExtract,
+ bool flatten1DDepthwiseConv) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -1696,8 +1698,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. Will require stride/dilation attributes inference.
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
- FailureOr<Operation *> convOr =
- vectorizeConvolution(rewriter, linalgOp);
+ FailureOr<Operation *> convOr = vectorizeConvolution(
+ rewriter, linalgOp, flatten1DDepthwiseConv);
if (succeeded(convOr)) {
llvm::append_range(results, (*convOr)->getResults());
return success();
@@ -2822,7 +2824,7 @@ struct Conv1DGenerator
/// kw is always unrolled.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
/// > 1.
- FailureOr<Operation *> depthwiseConv() {
+ FailureOr<Operation *> depthwiseConvGeneric() {
if (!valid)
return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
@@ -2936,6 +2938,176 @@ struct Conv1DGenerator
.getOperation();
}
+  /// Generate a vector implementation for ("flatten channel dim"):
+  /// ```
+  ///   Op def: (     n,     w,     c,    kw)
+  ///    Iters: ({Par(), Par(), Par(), Red()})
+  ///   Layout: {{n, 1 * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+  /// ```
+  /// c of the input/output is collapsed with w. kw is always unrolled and the
+  /// filter row for each kw is broadcast to match the collapsed w * c
+  /// dimension.
+  ///
+  /// Only unit strides are supported: for strideW > 1 the input elements that
+  /// feed consecutive outputs are no longer contiguous in the collapsed
+  /// dimension. Non-unit dilations need no special handling, because every kw
+  /// tap reads a contiguous slice of the collapsed input starting at
+  /// `kw * dilationW * cSize`.
+  ///
+  /// TODO: Add support for non-unit strides.
+  FailureOr<Operation *> depthwiseConvFlatten() {
+    if (!valid)
+      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
+
+    // Check the preconditions before creating any IR so that a failed match
+    // leaves the IR untouched.
+    if (strideW != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Non-unit strides are not supported yet");
+
+    int64_t nSize, iSize, wSize, cSize, kwSize;
+    // kernel{kw, c}
+    bindShapeDims(rhsShapedType, kwSize, cSize);
+    // out{n, w, c}
+    bindShapeDims(resShapedType, nSize, wSize);
+    // in{n, w, c}
+    bindShapeDims(lhsShapedType, nSize, iSize);
+
+    // With a unit stride the whole collapsed w * c dimension is processed in
+    // a single step, so the w loops below run exactly once.
+    const int64_t wSizeStep = wSize;
+
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+    Type lhsEltType = lhsShapedType.getElementType();
+    Type rhsEltType = rhsShapedType.getElementType();
+    Type resEltType = resShapedType.getElementType();
+    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+    VectorType lhsType = VectorType::get(
+        {nSize,
+         // iw = (((ow - 1) * sw + 1) + ((kw - 1) * dw + 1) - 1) * c
+         // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+         (((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1) *
+             cSize},
+        lhsEltType);
+    VectorType resType = VectorType::get({nSize, wSize * cSize}, resEltType);
+
+    // Read rhs slice of size {kw, c} @ [0, 0].
+    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                        ValueRange{zero, zero});
+
+    // Flatten the w and c dimensions of the input and of the output.
+    SmallVector<ReassociationIndices> reassociation = {{0}, {1, 2}};
+    auto lhsTypeCollapsed = VectorType::get({nSize, iSize * cSize}, lhsEltType);
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    auto collapse = [&](Value shaped, ArrayRef<int64_t> shape,
+                        Type eltType) -> Value {
+      if (linalgOp.hasTensorSemantics())
+        return rewriter.create<tensor::CollapseShapeOp>(
+            loc, RankedTensorType::get(shape, eltType), shaped, reassociation);
+      return rewriter.create<memref::CollapseShapeOp>(
+          loc, MemRefType::get(shape, eltType), shaped, reassociation);
+    };
+    Value lhsFlat =
+        collapse(lhsShaped, lhsTypeCollapsed.getShape(), lhsEltType);
+    Value resFlat = collapse(resShaped, resType.getShape(), resEltType);
+
+    // Read lhs slice of size {n, iw * c} @ [0, 0].
+    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsFlat,
+                                                        ValueRange{zero, zero});
+    // Read res slice of size {n, w * c} @ [0, 0].
+    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resFlat,
+                                                        ValueRange{zero, zero});
+
+    //===------------------------------------------------------------------===//
+    // Begin vector-only rewrite part
+    //===------------------------------------------------------------------===//
+    // Unroll along kw and read slices of lhs and rhs.
+    SmallVector<Value> lhsVals, rhsVals, resVals;
+    // Extract lhs slice of size {n, wSizeStep * c}
+    //   @ [0, (sw * w + dw * kw) * c].
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w = 0; w < wSize; w += wSizeStep) {
+        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, lhs,
+            /*offsets=*/
+            ArrayRef<int64_t>{0, (w * strideW + kw * dilationW) * cSize},
+            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep * cSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1}));
+      }
+    }
+    // Extract rhs slice of size {c} @ [kw].
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
+          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+    }
+
+    // Extract res slice of size {n, wSizeStep * c} @ [0, w * c].
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, res,
+          /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep * cSize},
+          /*strides=*/ArrayRef<int64_t>{1, 1}));
+    }
+
+    // lhsVals is laid out kw-major with (wSize / wSizeStep) slices per kw.
+    auto linearIndex = [&](int64_t kw, int64_t w) {
+      return kw * (wSize / wSizeStep) + w / wSizeStep;
+    };
+
+    // Compute contraction:
+    //   O{n, w * c} += I{n, (sw * w + dw * kw) * c} * F{c}
+    // Stop at the first failure: a null accumulator must not be fed into the
+    // next MulAcc.
+    bool mulAccFailed = false;
+    for (int64_t kw = 0; kw < kwSize && !mulAccFailed; ++kw) {
+      for (int64_t w = 0; w < wSize && !mulAccFailed; w += wSizeStep) {
+        Value &acc = resVals[w / wSizeStep];
+        acc = depthwiseConv1dFlatSliceAsMulAcc(
+            rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], acc);
+        mulAccFailed = !acc;
+      }
+    }
+
+    // It's possible we failed to create the FMA.
+    if (mulAccFailed) {
+      // Manually revert (in reverse order) to avoid leaving a bad IR state.
+      // This is best-effort: the failed accumulator is a null Value, and ops
+      // that still have users (e.g. ones created inside the MulAcc helper,
+      // which are not tracked here) cannot be erased safely, so both are
+      // skipped.
+      for (auto &collection : {resVals, rhsVals, lhsVals,
+                               {res, rhs, lhs, resFlat, lhsFlat, zero}})
+        for (Value v : collection) {
+          Operation *defOp = v ? v.getDefiningOp() : nullptr;
+          if (defOp && defOp->use_empty())
+            rewriter.eraseOp(defOp);
+        }
+      return rewriter.notifyMatchFailure(op, "failed to create FMA");
+    }
+
+    // Write back res slice of size {n, wSizeStep * c} @ [0, w * c]. This does
+    // not depend on kw.
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w / wSizeStep], res,
+          /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+          /*strides=*/ArrayRef<int64_t>{1, 1});
+    }
+    //===------------------------------------------------------------------===//
+    // End vector-only rewrite part
+    //===------------------------------------------------------------------===//
+    // Write back res slice of size {n, w * c} @ [0, 0].
+    vector::TransferWriteOp tWrite = rewriter.create<vector::TransferWriteOp>(
+        loc, res, resFlat, ValueRange{zero, zero});
+
+    // A tensor has to be reshaped back to its original shape ...
+    if (linalgOp.hasTensorSemantics())
+      return rewriter
+          .create<tensor::ExpandShapeOp>(loc, resShapedType, tWrite.getResult(),
+                                         reassociation)
+          .getOperation();
+    // ... memrefs don't require reshaping (the collapsed memref is just a
+    // different view into the same buffer).
+    return tWrite.getOperation();
+  }
+
/// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res) {
@@ -2959,6 +3131,39 @@ struct Conv1DGenerator
return rewriter.create<arith::AddIOp>(loc, mul, res);
}
+  /// Lower lhs{n, w * c} * rhs{c} -> res{n, w * c} to MulAcc.
+  ///
+  /// The per-channel filter values in `rhs` are first tiled along the
+  /// collapsed w * c dimension with a `vector.shuffle` ([0, ..., c-1] repeated
+  /// w times). They are then broadcast along n so that they line up
+  /// element-wise with `lhs`. Floating-point accumulators use `vector.fma`;
+  /// integer ones use `arith.muli` + `arith.addi`.
+  ///
+  /// Returns a null Value if promoting `lhs` or `rhs` to the element type of
+  /// `res` yields a null Value. If `lhs` fails to promote, no further IR is
+  /// created.
+  Value depthwiseConv1dFlatSliceAsMulAcc(RewriterBase &rewriter, Location loc,
+                                         Value lhs, Value rhs, Value res) {
+    auto rhsTy = cast<ShapedType>(rhs.getType());
+    auto resTy = cast<ShapedType>(res.getType());
+
+    lhs = promote(rewriter, loc, lhs, resTy);
+    // Bail out before building the rhs broadcast if promotion failed.
+    if (!lhs)
+      return nullptr;
+
+    // Shuffle mask tiling rhs{c} across the collapsed w * c dimension.
+    const int64_t rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
+    const int64_t resSize = cast<VectorType>(res.getType()).getShape()[1];
+    SmallVector<int64_t, 16> indices;
+    indices.reserve(resSize);
+    for (int64_t i = 0; i < resSize / rhsSize; ++i)
+      for (int64_t j = 0; j < rhsSize; ++j)
+        indices.push_back(j);
+
+    rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
+    rhs = rewriter.create<vector::BroadcastOp>(
+        loc, resTy.clone(rhsTy.getElementType()), rhs);
+    rhs = promote(rewriter, loc, rhs, resTy);
+    if (!rhs)
+      return nullptr;
+
+    if (isa<FloatType>(resTy.getElementType()))
+      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+
+    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
+    return rewriter.create<arith::AddIOp>(loc, mul, res);
+  }
+
/// Entry point for non-channeled convolution:
/// {{w + kw}, {kw}, {w}}
FailureOr<Operation *> generateNonChanneledConv() {
@@ -3049,7 +3254,7 @@ struct Conv1DGenerator
/// Entry point that transposes into the common form:
/// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
- FailureOr<Operation *> generateDilatedConv() {
+ FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
AffineExpr n, w, c, kw;
bindDims(ctx, n, w, c, kw);
if (!iters({Par(), Par(), Par(), Red()}))
@@ -3060,7 +3265,7 @@ struct Conv1DGenerator
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
/*rhsIndex*/ {kw, c},
/*resIndex*/ {n, w, c}}))
- return depthwiseConv();
+ return flatten ? depthwiseConvFlatten() : depthwiseConvGeneric();
return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
}
@@ -3125,8 +3330,9 @@ struct Conv1DGenerator
/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
- LinalgOp op) {
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+ bool flatten1DDepthwiseConv) {
// The ConvolutionOpInterface gives us guarantees of existence for
// strides/dilations. However, we do not need to rely on those, we can simply
// use them if present, otherwise use the default and let the generic conv.
@@ -3151,7 +3357,7 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
res = e.generateNcwPooling();
if (succeeded(res))
return res;
- return e.generateDilatedConv();
+ return e.generateDilatedConv(flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
new file mode 100644
index 0000000000000..6b0f920bfa42e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
@@ -0,0 +1,222 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+
+// Unit stride/dilation with tensor semantics: input and output are collapsed
+// to tensor<1x24xi8>, the filter row is tiled with vector.shuffle and
+// broadcast, and the result is expanded back with tensor.expand_shape.
+func.func @flatten_tensor(%input: tensor<1x8x3xi8>, %filter: tensor<1x3xi8>, %output: tensor<1x8x3xi8>) -> (tensor<1x8x3xi8>) {
+ %res = linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<1> : vector<1xi64>,
+ strides = dense<1> : vector<1xi64>}
+ ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
+ outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
+ return %res : tensor<1x8x3xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @flatten_tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x3xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x3xi8>, vector<1x3xi8>
+// CHECK: %[[VAL_6:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
+// CHECK: %[[VAL_7:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x24xi8>, vector<1x24xi8>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x24xi8>, vector<1x24xi8>
+// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_5]][0] : vector<3xi8> from vector<1x3xi8>
+// CHECK: %[[VAL_11:.*]] = vector.shuffle %[[VAL_10]], %[[VAL_10]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
+// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<24xi8> to vector<1x24xi8>
+// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_12]] : vector<1x24xi8>
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : vector<1x24xi8>
+// CHECK: %[[VAL_15:.*]] = vector.transfer_write %[[VAL_14]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x24xi8>, tensor<1x24xi8>
+// CHECK: %[[VAL_16:.*]] = tensor.expand_shape %[[VAL_15]] {{\[\[}}0], [1, 2]] : tensor<1x24xi8> into tensor<1x8x3xi8>
+// CHECK: return %[[VAL_16]] : tensor<1x8x3xi8>
+// CHECK: }
+
+// -----
+
+// Unit stride/dilation with memref semantics: input and output are collapsed
+// with memref.collapse_shape and the result is written straight into the
+// collapsed output, with no expand_shape afterwards.
+func.func @flatten_memref(%input: memref<1x8x3xi8>, %filter: memref<1x3xi8>, %output: memref<1x8x3xi8>) {
+ linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<1> : vector<1xi64>,
+ strides = dense<1> : vector<1xi64>}
+ ins(%input, %filter : memref<1x8x3xi8>, memref<1x3xi8>)
+ outs(%output : memref<1x8x3xi8>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @flatten_memref(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x8x3xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x3xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<1x8x3xi8>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x3xi8>, vector<1x3xi8>
+// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<1x8x3xi8> into memref<1x24xi8>
+// CHECK: %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<1x8x3xi8> into memref<1x24xi8>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x24xi8>, vector<1x24xi8>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x24xi8>, vector<1x24xi8>
+// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_5]][0] : vector<3xi8> from vector<1x3xi8>
+// CHECK: %[[VAL_11:.*]] = vector.shuffle %[[VAL_10]], %[[VAL_10]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
+// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<24xi8> to vector<1x24xi8>
+// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_12]] : vector<1x24xi8>
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : vector<1x24xi8>
+// CHECK: vector.transfer_write %[[VAL_14]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x24xi8>, memref<1x24xi8>
+// CHECK: return
+// CHECK: }
+
+// -----
+
+// Filter with kw = 2 (unit stride/dilation): kw is unrolled, so two 1x21
+// slices of the collapsed 1x24 input (at offsets 0 and kw * c = 3) are each
+// multiplied by the tiled filter row for that kw and accumulated.
+func.func @flatten_memref_wider_filter(%input: memref<1x8x3xi8>, %filter: memref<2x3xi8>, %output: memref<1x7x3xi8>) {
+ linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<1> : vector<1xi64>,
+ strides = dense<1> : vector<1xi64>}
+ ins(%input, %filter : memref<1x8x3xi8>, memref<2x3xi8>)
+ outs(%output : memref<1x7x3xi8>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @flatten_memref_wider_filter(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x8x3xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<2x3xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<1x7x3xi8>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x3xi8>, vector<2x3xi8>
+// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<1x8x3xi8> into memref<1x24xi8>
+// CHECK: %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<1x7x3xi8> into memref<1x21xi8>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x24xi8>, vector<1x24xi8>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x21xi8>, vector<1x21xi8>
+// CHECK: %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 0], sizes = [1, 21], strides = [1, 1]} : vector<1x24xi8> to vector<1x21xi8>
+// CHECK: %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 3], sizes = [1, 21], strides = [1, 1]} : vector<1x24xi8> to vector<1x21xi8>
+// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][0] : vector<3xi8> from vector<2x3xi8>
+// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_5]][1] : vector<3xi8> from vector<2x3xi8>
+// CHECK: %[[VAL_14:.*]] = vector.shuffle %[[VAL_12]], %[[VAL_12]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
+// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_14]] : vector<21xi8> to vector<1x21xi8>
+// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_10]], %[[VAL_15]] : vector<1x21xi8>
+// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_9]] : vector<1x21xi8>
+// CHECK: %[[VAL_18:.*]] = vector.shuffle %[[VAL_13]], %[[VAL_13]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
+// CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_18]] : vector<21xi8> to vector<1x21xi8>
+// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_11]], %[[VAL_19]] : vector<1x21xi8>
+// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_17]] : vector<1x21xi8>
+// CHECK: vector.transfer_write %[[VAL_21]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x21xi8>, memref<1x21xi8>
+// CHECK: return
+// CHECK: }
+
+// -----
+
+// Non-unit dilation (2) with f32: the kw = 1 tap reads the collapsed input at
+// offset kw * dilation * c = 8, and the float accumulator lowers the
+// multiply-accumulate to vector.fma.
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+ linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
+ outs(%output : memref<3x2x4xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5x4xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<2x4xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x2x4xf32>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
+// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<3x5x4xf32> into memref<3x20xf32>
+// CHECK: %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<3x2x4xf32> into memref<3x8xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x20xf32>, vector<3x16xf32>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x8xf32>, vector<3x8xf32>
+// CHECK: %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 0], sizes = [3, 8], strides = [1, 1]} : vector<3x16xf32> to vector<3x8xf32>
+// CHECK: %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 8], sizes = [3, 8], strides = [1, 1]} : vector<3x16xf32> to vector<3x8xf32>
+// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][0] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_5]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[VAL_14:.*]] = vector.shuffle %[[VAL_12]], %[[VAL_12]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_14]] : vector<8xf32> to vector<3x8xf32>
+// CHECK: %[[VAL_16:.*]] = vector.fma %[[VAL_10]], %[[VAL_15]], %[[VAL_9]] : vector<3x8xf32>
+// CHECK: %[[VAL_17:.*]] = vector.shuffle %[[VAL_13]], %[[VAL_13]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_17]] : vector<8xf32> to vector<3x8xf32>
+// CHECK: %[[VAL_19:.*]] = vector.fma %[[VAL_11]], %[[VAL_18]], %[[VAL_16]] : vector<3x8xf32>
+// CHECK: vector.transfer_write %[[VAL_19]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x8xf32>, memref<3x8xf32>
+// CHECK: return
+// CHECK: }
+
+// -----
+
+// Non-unit dilation (2) with mixed element types: the i8 input and filter
+// slices are sign-extended (arith.extsi) to the i32 accumulator type before
+// arith.muli/arith.addi. Requires {flatten_1d_depthwise_conv}, otherwise the
+// default (non-flattened) lowering is produced.
+func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %filter: memref<2x4xi8>, %output: memref<3x2x4xi32>) {
+ linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
+ outs(%output : memref<3x2x4xi32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5x4xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<2x4xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x2x4xi32>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_5:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xi8>, vector<2x4xi8>
+// CHECK: %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<3x5x4xi8> into memref<3x20xi8>
+// CHECK: %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<3x2x4xi32> into memref<3x8xi32>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x20xi8>, vector<3x16xi8>
+// CHECK: %[[VAL_10:.*]] = vector.transfer_read %[[VAL_8]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_5]] {in_bounds = [true, true]} : memref<3x8xi32>, vector<3x8xi32>
+// CHECK: %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_9]] {offsets = [0, 0], sizes = [3, 8], strides = [1, 1]} : vector<3x16xi8> to vector<3x8xi8>
+// CHECK: %[[VAL_12:.*]] = vector.extract_strided_slice %[[VAL_9]] {offsets = [0, 8], sizes = [3, 8], strides = [1, 1]} : vector<3x16xi8> to vector<3x8xi8>
+// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_6]][0] : vector<4xi8> from vector<2x4xi8>
+// CHECK: %[[VAL_14:.*]] = vector.extract %[[VAL_6]][1] : vector<4xi8> from vector<2x4xi8>
+// CHECK: %[[VAL_15:.*]] = arith.extsi %[[VAL_11]] : vector<3x8xi8> to vector<3x8xi32>
+// CHECK: %[[VAL_16:.*]] = vector.shuffle %[[VAL_13]], %[[VAL_13]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
+// CHECK: %[[VAL_17:.*]] = vector.broadcast %[[VAL_16]] : vector<8xi8> to vector<3x8xi8>
+// CHECK: %[[VAL_18:.*]] = arith.extsi %[[VAL_17]] : vector<3x8xi8> to vector<3x8xi32>
+// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_15]], %[[VAL_18]] : vector<3x8xi32>
+// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_10]] : vector<3x8xi32>
+// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_12]] : vector<3x8xi8> to vector<3x8xi32>
+// CHECK: %[[VAL_22:.*]] = vector.shuffle %[[VAL_14]], %[[VAL_14]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
+// CHECK: %[[VAL_23:.*]] = vector.broadcast %[[VAL_22]] : vector<8xi8> to vector<3x8xi8>
+// CHECK: %[[VAL_24:.*]] = arith.extsi %[[VAL_23]] : vector<3x8xi8> to vector<3x8xi32>
+// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_21]], %[[VAL_24]] : vector<3x8xi32>
+// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_20]] : vector<3x8xi32>
+// CHECK: vector.transfer_write %[[VAL_26]], %[[VAL_8]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x8xi32>, memref<3x8xi32>
+// CHECK: return
+// CHECK: }
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 93e36a69567bd..59b242d530442 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -535,9 +535,9 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
/// Read the whole data in one shot.
-// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
-// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
-// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
@@ -575,9 +575,9 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %fi
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
/// Read the whole data in one shot.
-// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
-// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
-// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
>From 02415b5de2c054f158870d04eaa913b701ba8713 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 15 Nov 2023 18:38:07 +0000
Subject: [PATCH 2/4] fixup! [mlir][linalg][conv] Flatten the channel dimension
when vectorizing
Following on from Nicolas' observation, this commit refactors the
implementation to simply replace:
```
resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
lhsVals[linearIndex(kw, w)],
rhsVals[kw], resVals[w]);
```
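with shape_cast (collapse w and c into one dimension) +
depthwiseConv1dSliceAsMulAcc + shape_cast (restore the original {n, w, c}
shape). This removes the separate depthwiseConvFlatten and
depthwiseConv1dFlatSliceAsMulAcc implementations.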
---
.../Linalg/Transforms/Vectorization.cpp | 252 +++---------------
.../Linalg/vectorize-convolution-flatten.mlir | 162 +++++------
2 files changed, 118 insertions(+), 296 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0a947af87144b..c18e5d627a9de 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2824,7 +2824,7 @@ struct Conv1DGenerator
/// kw is always unrolled.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
/// > 1.
- FailureOr<Operation *> depthwiseConvGeneric() {
+ FailureOr<Operation *> depthwiseConvGeneric(bool flatten) {
if (!valid)
return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
@@ -2871,6 +2871,9 @@ struct Conv1DGenerator
//===------------------------------------------------------------------===//
// Unroll along kw and read slices of lhs and rhs.
SmallVector<Value> lhsVals, rhsVals, resVals;
+ auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
+ auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
+
// Extract lhs slice of size {n, wSizeStep, c}
// @ [0, sw * w + dw * kw, 0].
for (int64_t kw = 0; kw < kwSize; ++kw) {
@@ -2878,8 +2881,7 @@ struct Conv1DGenerator
lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
loc, lhs,
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
- /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
- /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+ inOutSliceSizes, inOutStrides));
}
}
// Extract rhs slice of size {c} @ [kw].
@@ -2891,21 +2893,38 @@ struct Conv1DGenerator
for (int64_t w = 0; w < wSize; w += wSizeStep) {
resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
loc, res,
- /*offsets=*/ArrayRef<int64_t>{0, w, 0},
- /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
- /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+ /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
+ inOutStrides));
}
auto linearIndex = [&](int64_t kw, int64_t w) {
return kw * (wSize / wSizeStep) + w;
};
+ auto inOutFlattenSliceSizes =
+ SmallVector<int64_t>{nSize, wSizeStep * cSize};
+ // lhs and res can have different element types (e.g. i8 input with an i32
+ // accumulator) and vector.shape_cast cannot change the element type, so
+ // each flattened type must keep its own operand's element type.
+ auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+ auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
- resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
- lhsVals[linearIndex(kw, w)],
- rhsVals[kw], resVals[w]);
+ Value lhsVal = lhsVals[linearIndex(kw, w)];
+ Value resVal = resVals[w];
+ ShapedType filterBCastTy = cast<ShapedType>(resVal.getType());
+ if (flatten) {
+ lhsVal = rewriter.create<vector::ShapeCastOp>(
+ loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
+ resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
+ resVals[w]);
+ }
+ resVals[w] = depthwiseConv1dSliceAsMulAcc(
+ rewriter, loc, lhsVal, rhsVals[kw], resVal, filterBCastTy, flatten);
+ if (flatten)
+ resVals[w] = rewriter.create<vector::ShapeCastOp>(
+ loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
}
}
@@ -2938,179 +2954,13 @@ struct Conv1DGenerator
.getOperation();
}
- /// Generate a vector implementation for ("flatten channel dim"):
- /// ```
- /// Op def: ( n, w, c, kw)
- /// Iters: ({Par(), Par(), Par(), Red()})
- /// Layout: {{n, 1 * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
- /// ```
- /// c of the input/output is collapsed with w. kw is always unrolled and
- /// broadcast to match w.
- ///
- /// TODO: Add support for non-unit stride/dilation
- /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
- /// > 1.
- FailureOr<Operation *> depthwiseConvFlatten() {
- if (!valid)
- return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
-
- int64_t nSize, iSize, wSize, cSize, kwSize;
- // kernel{kw, c}
- bindShapeDims(rhsShapedType, kwSize, cSize);
- // out{n, w, c}
- bindShapeDims(resShapedType, nSize, wSize);
- // in{n, w, c}
- bindShapeDims(lhsShapedType, nSize, iSize);
-
- vector::TransferWriteOp write;
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-
- if (strideW == 1)
- return rewriter.notifyMatchFailure(
- op, "Non-unit strides are not supported yet");
- if (dilationW == 1)
- return rewriter.notifyMatchFailure(
- op, "Non-unit dilations are not supported yet");
-
- Type lhsEltType = lhsShapedType.getElementType();
- Type rhsEltType = rhsShapedType.getElementType();
- Type resEltType = resShapedType.getElementType();
- VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
- VectorType lhsType = VectorType::get(
- {nSize,
- // iw = (ow * sw + kw * dw - 1) * c
- // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
- (((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1) *
- cSize},
- lhsEltType);
-
- VectorType resType = VectorType::get({nSize, wSize * cSize}, resEltType);
-
- Value res, lhs, lhsFlat, resFlat;
- // Read rhs slice of size {kw, c} @ [0, 0].
- Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
- ValueRange{zero, zero});
-
- SmallVector<ReassociationIndices> reassociation = {{0}, {1, 2}};
-
- // Flatten w and c dimensions
- auto lhsTypeCollapsed = VectorType::get({nSize, iSize * cSize}, lhsEltType);
- auto linalgOp = dyn_cast<LinalgOp>(op);
- lhsFlat =
- linalgOp.hasTensorSemantics()
- ? (Value)rewriter.create<tensor::CollapseShapeOp>(
- loc,
- RankedTensorType::get(lhsTypeCollapsed.getShape(),
- lhsEltType),
- lhsShaped, reassociation)
- : (Value)rewriter.create<memref::CollapseShapeOp>(
- loc, MemRefType::get(lhsTypeCollapsed.getShape(), lhsEltType),
- lhsShaped, reassociation);
- resFlat =
- linalgOp.hasTensorSemantics()
- ? (Value)rewriter.create<tensor::CollapseShapeOp>(
- loc, RankedTensorType::get(resType.getShape(), resEltType),
- resShaped, reassociation)
- : (Value)rewriter.create<memref::CollapseShapeOp>(
- loc, MemRefType::get(resType.getShape(), resEltType),
- resShaped, reassociation);
-
- // Read lhs slice of size {n, (w * wSize + kw * dilationW) * c} @ [0,
- // 0].
- lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsFlat,
- ValueRange{zero, zero});
- // Read res slice of size {n, w * c} @ [0, 0].
- res = rewriter.create<vector::TransferReadOp>(loc, resType, resFlat,
- ValueRange{zero, zero});
-
- //===------------------------------------------------------------------===//
- // Begin vector-only rewrite part
- //===------------------------------------------------------------------===//
- // Unroll along kw and read slices of lhs and rhs.
- SmallVector<Value> lhsVals, rhsVals, resVals;
- // Extract lhs slice of size {n, wSizeStep * c}
- // @ [0, (sw * w + dw * kw) * cSize].
- for (int64_t kw = 0; kw < kwSize; ++kw) {
- for (int64_t w = 0; w < wSize; w += wSize) {
- lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
- loc, lhs,
- /*offsets=*/
- ArrayRef<int64_t>{0, (w * wSize + kw * dilationW) * cSize},
- /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
- /*strides=*/ArrayRef<int64_t>{1, 1}));
- }
- }
- // Extract rhs slice of size {c} @ [kw].
- for (int64_t kw = 0; kw < kwSize; ++kw) {
- rhsVals.push_back(rewriter.create<vector::ExtractOp>(
- loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
- }
-
- // Extract res slice
- // Flattened case: {n, wSizeStep * c} @ [0, w].
- for (int64_t w = 0; w < wSize; w += wSize) {
- resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
- loc, res,
- /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
- /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
- /*strides=*/ArrayRef<int64_t>{1, 1}));
- }
-
- auto linearIndex = [&](int64_t kw, int64_t w) {
- return kw * (wSize / wSize) + w;
- };
-
- // Compute contraction:
- // O{n, w * c} += I{n, (sw * w + dw * kw) * c} * F{c}
- for (int64_t kw = 0; kw < kwSize; ++kw) {
- for (int64_t w = 0; w < wSize; w += wSize) {
- resVals[w] = depthwiseConv1dFlatSliceAsMulAcc(
- rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
- resVals[w]);
- }
- }
-
- // Its possible we failed to create the Fma.
- if (!llvm::all_of(resVals, [](Value v) { return v; })) {
- // Manually revert (in reverse order) to avoid leaving a bad IR state.
- for (auto &collection :
- {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
- for (Value v : collection)
- rewriter.eraseOp(v.getDefiningOp());
- return rewriter.notifyMatchFailure(op, "failed to create FMA");
- }
-
- // Write back res slice. This does not depend on kw.
- // Flattened case: {n, wSizeStep * c} @ [0, w].
- for (int64_t w = 0; w < wSize; w += wSize) {
- res = rewriter.create<vector::InsertStridedSliceOp>(
- loc, resVals[w], res,
- /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
- /*strides=*/ArrayRef<int64_t>{1, 1});
- }
- //===------------------------------------------------------------------===//
- // End vector-only rewrite part
- //===------------------------------------------------------------------===//
- // Write back res slice of size {n, w * c} @ [0, 0].
- mlir::vector::TransferWriteOp tWrite =
- rewriter.create<vector::TransferWriteOp>(loc, res, resFlat,
- ValueRange{zero, zero});
-
- // A tensor has to be re-shaped back to it's original shape ...
- if (linalgOp.hasTensorSemantics())
- // Re-expand shape
- return rewriter
- .create<tensor::ExpandShapeOp>(loc, resShapedType, tWrite.getResult(),
- reassociation)
- .getOperation();
- /// ... memrefs don't requie reshaping (re-shape is just a different view
- /// into the same memref)
- return tWrite.getOperation();
- }
-
- /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
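+ /// Lower lhs * rhs -> res to MulAcc. rhs{c} is first broadcast to `bcastTy`
+ /// (the caller passes the unflattened res slice type) and, when `flatten` is
+ /// set, shape-cast to the flattened shape. Shapes: lhs/res{n, w, c}
+ /// (flatten = false) or lhs/res{n, w * c} (flatten = true).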
Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
- Value lhs, Value rhs, Value res) {
+ Value lhs, Value rhs, Value res,
+ ShapedType bcastTy, bool flatten) {
auto rhsTy = cast<ShapedType>(rhs.getType());
auto resTy = cast<ShapedType>(res.getType());
@@ -3118,46 +2968,20 @@ struct Conv1DGenerator
lhs = promote(rewriter, loc, lhs, resTy);
rhs = rewriter.create<vector::BroadcastOp>(
- loc, resTy.clone(rhsTy.getElementType()), rhs);
- rhs = promote(rewriter, loc, rhs, resTy);
-
- if (!lhs || !rhs)
- return nullptr;
-
- if (isa<FloatType>(resTy.getElementType()))
- return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+ loc, bcastTy.clone(rhsTy.getElementType()), rhs);
+ // vector.shape_cast cannot change the element type: flatten rhs to the
+ // shape of res but keep rhs's own element type, and leave any element
+ // type conversion (e.g. i8 -> i32) to promote() below.
+ if (flatten)
+ rhs = rewriter.create<vector::ShapeCastOp>(
+ loc, resTy.clone(rhsTy.getElementType()), rhs);
- auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
- return rewriter.create<arith::AddIOp>(loc, mul, res);
- }
-
- /// Lower lhs{n, w * c} * rhs{c} -> res{n, w * c} to MulAcc
- Value depthwiseConv1dFlatSliceAsMulAcc(RewriterBase &rewriter, Location loc,
- Value lhs, Value rhs, Value res) {
- auto rhsTy = rhs.getType().cast<ShapedType>();
- auto resTy = res.getType().cast<ShapedType>();
-
- lhs = promote(rewriter, loc, lhs, resTy);
-
- auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
- auto resSize = res.getType().cast<VectorType>().getShape()[1];
-
- SmallVector<int64_t, 16> indicies;
- for (int i = 0; i < resSize / rhsSize; ++i) {
- for (int j = 0; j < rhsSize; ++j)
- indicies.push_back(j);
- }
-
- rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indicies);
-
- rhs = rewriter.create<vector::BroadcastOp>(
- loc, resTy.clone(rhsTy.getElementType()), rhs);
rhs = promote(rewriter, loc, rhs, resTy);
if (!lhs || !rhs)
return nullptr;
- if (resTy.getElementType().isa<FloatType>())
+ if (isa<FloatType>(resTy.getElementType()))
return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
@@ -3265,7 +3085,7 @@ struct Conv1DGenerator
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
/*rhsIndex*/ {kw, c},
/*resIndex*/ {n, w, c}}))
- return flatten ? depthwiseConvFlatten() : depthwiseConvGeneric();
+ return depthwiseConvGeneric(flatten);
return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
}
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
index 6b0f920bfa42e..17f93ed46e779 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
@@ -17,25 +17,24 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-
// CHECK-LABEL: func.func @flatten_tensor(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x3xi8>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3xi8>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> {
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8
-// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x3xi8>, vector<1x3xi8>
-// CHECK: %[[VAL_6:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
-// CHECK: %[[VAL_7:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
-// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x24xi8>, vector<1x24xi8>
-// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x24xi8>, vector<1x24xi8>
-// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_5]][0] : vector<3xi8> from vector<1x3xi8>
-// CHECK: %[[VAL_11:.*]] = vector.shuffle %[[VAL_10]], %[[VAL_10]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
-// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<24xi8> to vector<1x24xi8>
-// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_12]] : vector<1x24xi8>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : vector<1x24xi8>
-// CHECK: %[[VAL_15:.*]] = vector.transfer_write %[[VAL_14]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x24xi8>, tensor<1x24xi8>
-// CHECK: %[[VAL_16:.*]] = tensor.expand_shape %[[VAL_15]] {{\[\[}}0], [1, 2]] : tensor<1x24xi8> into tensor<1x8x3xi8>
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : tensor<1x8x3xi8>, vector<1x8x3xi8>
+// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x3xi8>, vector<1x3xi8>
+// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : tensor<1x8x3xi8>, vector<1x8x3xi8>
+// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_6]][0] : vector<3xi8> from vector<1x3xi8>
+// CHECK: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_5]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_8]] : vector<3xi8> to vector<1x8x3xi8>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_9]], %[[VAL_12]] : vector<1x24xi8>
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : vector<1x24xi8>
+// CHECK: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<1x24xi8> to vector<1x8x3xi8>
+// CHECK: %[[VAL_16:.*]] = vector.transfer_write %[[VAL_15]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<1x8x3xi8>, tensor<1x8x3xi8>
// CHECK: return %[[VAL_16]] : tensor<1x8x3xi8>
// CHECK: }
@@ -65,17 +64,18 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[VAL_2:.*]]: memref<1x8x3xi8>) {
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8
-// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x3xi8>, vector<1x3xi8>
-// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<1x8x3xi8> into memref<1x24xi8>
-// CHECK: %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<1x8x3xi8> into memref<1x24xi8>
-// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x24xi8>, vector<1x24xi8>
-// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x24xi8>, vector<1x24xi8>
-// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_5]][0] : vector<3xi8> from vector<1x3xi8>
-// CHECK: %[[VAL_11:.*]] = vector.shuffle %[[VAL_10]], %[[VAL_10]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
-// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<24xi8> to vector<1x24xi8>
-// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_12]] : vector<1x24xi8>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : vector<1x24xi8>
-// CHECK: vector.transfer_write %[[VAL_14]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x24xi8>, memref<1x24xi8>
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
+// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x3xi8>, vector<1x3xi8>
+// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
+// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_6]][0] : vector<3xi8> from vector<1x3xi8>
+// CHECK: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_5]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_8]] : vector<3xi8> to vector<1x8x3xi8>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_9]], %[[VAL_12]] : vector<1x24xi8>
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : vector<1x24xi8>
+// CHECK: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<1x24xi8> to vector<1x8x3xi8>
+// CHECK: vector.transfer_write %[[VAL_15]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<1x8x3xi8>, memref<1x8x3xi8>
// CHECK: return
// CHECK: }
@@ -105,27 +105,30 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[VAL_2:.*]]: memref<1x7x3xi8>) {
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8
-// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x3xi8>, vector<2x3xi8>
-// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<1x8x3xi8> into memref<1x24xi8>
-// CHECK: %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<1x7x3xi8> into memref<1x21xi8>
-// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x24xi8>, vector<1x24xi8>
-// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x21xi8>, vector<1x21xi8>
-// CHECK: %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 0], sizes = [1, 21], strides = [1, 1]} : vector<1x24xi8> to vector<1x21xi8>
-// CHECK: %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 3], sizes = [1, 21], strides = [1, 1]} : vector<1x24xi8> to vector<1x21xi8>
-// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][0] : vector<3xi8> from vector<2x3xi8>
-// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_5]][1] : vector<3xi8> from vector<2x3xi8>
-// CHECK: %[[VAL_14:.*]] = vector.shuffle %[[VAL_12]], %[[VAL_12]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
-// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_14]] : vector<21xi8> to vector<1x21xi8>
-// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_10]], %[[VAL_15]] : vector<1x21xi8>
-// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_9]] : vector<1x21xi8>
-// CHECK: %[[VAL_18:.*]] = vector.shuffle %[[VAL_13]], %[[VAL_13]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
-// CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_18]] : vector<21xi8> to vector<1x21xi8>
-// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_11]], %[[VAL_19]] : vector<1x21xi8>
-// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_17]] : vector<1x21xi8>
-// CHECK: vector.transfer_write %[[VAL_21]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x21xi8>, memref<1x21xi8>
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
+// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x3xi8>, vector<2x3xi8>
+// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<1x7x3xi8>, vector<1x7x3xi8>
+// CHECK: %[[VAL_8:.*]] = vector.extract_strided_slice %[[VAL_5]] {offsets = [0, 0, 0], sizes = [1, 7, 3], strides = [1, 1, 1]} : vector<1x8x3xi8> to vector<1x7x3xi8>
+// CHECK: %[[VAL_9:.*]] = vector.extract_strided_slice %[[VAL_5]] {offsets = [0, 1, 0], sizes = [1, 7, 3], strides = [1, 1, 1]} : vector<1x8x3xi8> to vector<1x7x3xi8>
+// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<3xi8> from vector<2x3xi8>
+// CHECK: %[[VAL_11:.*]] = vector.extract %[[VAL_6]][1] : vector<3xi8> from vector<2x3xi8>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_8]] : vector<1x7x3xi8> to vector<1x21xi8>
+// CHECK: %[[VAL_13:.*]] = vector.shape_cast %[[VAL_7]] : vector<1x7x3xi8> to vector<1x21xi8>
+// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_10]] : vector<3xi8> to vector<1x7x3xi8>
+// CHECK: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<1x7x3xi8> to vector<1x21xi8>
+// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_12]], %[[VAL_15]] : vector<1x21xi8>
+// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_13]] : vector<1x21xi8>
+// CHECK: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_9]] : vector<1x7x3xi8> to vector<1x21xi8>
+// CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_11]] : vector<3xi8> to vector<1x7x3xi8>
+// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<1x7x3xi8> to vector<1x21xi8>
+// CHECK: %[[VAL_21:.*]] = arith.muli %[[VAL_18]], %[[VAL_20]] : vector<1x21xi8>
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_17]] : vector<1x21xi8>
+// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[VAL_22]] : vector<1x21xi8> to vector<1x7x3xi8>
+// CHECK: vector.transfer_write %[[VAL_23]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<1x7x3xi8>, memref<1x7x3xi8>
// CHECK: return
// CHECK: }
+
// -----
func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
@@ -151,22 +154,24 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x2x4xf32>) {
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
-// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<3x5x4xf32> into memref<3x20xf32>
-// CHECK: %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<3x2x4xf32> into memref<3x8xf32>
-// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x20xf32>, vector<3x16xf32>
-// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x8xf32>, vector<3x8xf32>
-// CHECK: %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 0], sizes = [3, 8], strides = [1, 1]} : vector<3x16xf32> to vector<3x8xf32>
-// CHECK: %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 8], sizes = [3, 8], strides = [1, 1]} : vector<3x16xf32> to vector<3x8xf32>
-// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][0] : vector<4xf32> from vector<2x4xf32>
-// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_5]][1] : vector<4xf32> from vector<2x4xf32>
-// CHECK: %[[VAL_14:.*]] = vector.shuffle %[[VAL_12]], %[[VAL_12]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
-// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_14]] : vector<8xf32> to vector<3x8xf32>
-// CHECK: %[[VAL_16:.*]] = vector.fma %[[VAL_10]], %[[VAL_15]], %[[VAL_9]] : vector<3x8xf32>
-// CHECK: %[[VAL_17:.*]] = vector.shuffle %[[VAL_13]], %[[VAL_13]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
-// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_17]] : vector<8xf32> to vector<3x8xf32>
-// CHECK: %[[VAL_19:.*]] = vector.fma %[[VAL_11]], %[[VAL_18]], %[[VAL_16]] : vector<3x8xf32>
-// CHECK: vector.transfer_write %[[VAL_19]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x8xf32>, memref<3x8xf32>
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<3x5x4xf32>, vector<3x4x4xf32>
+// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
+// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<3x2x4xf32>, vector<3x2x4xf32>
+// CHECK: %[[VAL_8:.*]] = vector.extract_strided_slice %[[VAL_5]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.extract_strided_slice %[[VAL_5]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
+// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[VAL_11:.*]] = vector.extract %[[VAL_6]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_8]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK: %[[VAL_13:.*]] = vector.shape_cast %[[VAL_7]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_10]] : vector<4xf32> to vector<3x2x4xf32>
+// CHECK: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK: %[[VAL_16:.*]] = vector.fma %[[VAL_12]], %[[VAL_15]], %[[VAL_13]] : vector<3x8xf32>
+// CHECK: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_9]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_11]] : vector<4xf32> to vector<3x2x4xf32>
+// CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_18]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK: %[[VAL_20:.*]] = vector.fma %[[VAL_17]], %[[VAL_19]], %[[VAL_16]] : vector<3x8xf32>
+// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_20]] : vector<3x8xf32> to vector<3x2x4xf32>
+// CHECK: vector.transfer_write %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<3x2x4xf32>, memref<3x2x4xf32>
// CHECK: return
// CHECK: }
@@ -196,27 +201,24 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xi8>, vector<2x4xi8>
-// CHECK: %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<3x5x4xi8> into memref<3x20xi8>
-// CHECK: %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<3x2x4xi32> into memref<3x8xi32>
-// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x20xi8>, vector<3x16xi8>
-// CHECK: %[[VAL_10:.*]] = vector.transfer_read %[[VAL_8]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_5]] {in_bounds = [true, true]} : memref<3x8xi32>, vector<3x8xi32>
-// CHECK: %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_9]] {offsets = [0, 0], sizes = [3, 8], strides = [1, 1]} : vector<3x16xi8> to vector<3x8xi8>
-// CHECK: %[[VAL_12:.*]] = vector.extract_strided_slice %[[VAL_9]] {offsets = [0, 8], sizes = [3, 8], strides = [1, 1]} : vector<3x16xi8> to vector<3x8xi8>
-// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_6]][0] : vector<4xi8> from vector<2x4xi8>
-// CHECK: %[[VAL_14:.*]] = vector.extract %[[VAL_6]][1] : vector<4xi8> from vector<2x4xi8>
-// CHECK: %[[VAL_15:.*]] = arith.extsi %[[VAL_11]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK: %[[VAL_16:.*]] = vector.shuffle %[[VAL_13]], %[[VAL_13]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
-// CHECK: %[[VAL_17:.*]] = vector.broadcast %[[VAL_16]] : vector<8xi8> to vector<3x8xi8>
-// CHECK: %[[VAL_18:.*]] = arith.extsi %[[VAL_17]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_15]], %[[VAL_18]] : vector<3x8xi32>
-// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_10]] : vector<3x8xi32>
-// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_12]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK: %[[VAL_22:.*]] = vector.shuffle %[[VAL_14]], %[[VAL_14]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
-// CHECK: %[[VAL_23:.*]] = vector.broadcast %[[VAL_22]] : vector<8xi8> to vector<3x8xi8>
-// CHECK: %[[VAL_24:.*]] = arith.extsi %[[VAL_23]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_21]], %[[VAL_24]] : vector<3x8xi32>
-// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_20]] : vector<3x8xi32>
-// CHECK: vector.transfer_write %[[VAL_26]], %[[VAL_8]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x8xi32>, memref<3x8xi32>
+// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<3x5x4xi8>, vector<3x4x4xi8>
+// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xi8>, vector<2x4xi8>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_5]] {in_bounds = [true, true, true]} : memref<3x2x4xi32>, vector<3x2x4xi32>
+// CHECK: %[[VAL_9:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+// CHECK: %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+// CHECK: %[[VAL_11:.*]] = vector.extract %[[VAL_7]][0] : vector<4xi8> from vector<2x4xi8>
+// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_7]][1] : vector<4xi8> from vector<2x4xi8>
+// CHECK: %[[VAL_13:.*]] = arith.extsi %[[VAL_9]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+// CHECK: %[[VAL_14:.*]] = arith.extsi %[[VAL_11]] : vector<4xi8> to vector<4xi32>
+// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_14]] : vector<4xi32> to vector<3x2x4xi32>
+// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_13]], %[[VAL_15]] : vector<3x2x4xi32>
+// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_8]] : vector<3x2x4xi32>
+// CHECK: %[[VAL_18:.*]] = arith.extsi %[[VAL_10]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+// CHECK: %[[VAL_19:.*]] = arith.extsi %[[VAL_12]] : vector<4xi8> to vector<4xi32>
+// CHECK: %[[VAL_20:.*]] = vector.broadcast %[[VAL_19]] : vector<4xi32> to vector<3x2x4xi32>
+// CHECK: %[[VAL_21:.*]] = arith.muli %[[VAL_18]], %[[VAL_20]] : vector<3x2x4xi32>
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_17]] : vector<3x2x4xi32>
+// CHECK: vector.transfer_write %[[VAL_22]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<3x2x4xi32>, memref<3x2x4xi32>
// CHECK: return
// CHECK: }
+
>From 69fa20ebbdaf34e615077f39388bca2f2eb75b2b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 16 Nov 2023 10:24:45 +0000
Subject: [PATCH 3/4] fixup! [mlir][linalg][conv] Flatten the channel dimension
when vectorizing
Revert renaming + fix formatting
---
.../lib/Dialect/Linalg/Transforms/Vectorization.cpp | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c18e5d627a9de..b618a1b1daeea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2824,7 +2824,7 @@ struct Conv1DGenerator
/// kw is always unrolled.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
/// > 1.
- FailureOr<Operation *> depthwiseConvGeneric(bool flatten) {
+ FailureOr<Operation *> depthwiseConv(bool flatten) {
if (!valid)
return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
@@ -2881,7 +2881,8 @@ struct Conv1DGenerator
lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
loc, lhs,
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
- inOutSliceSizes, inOutStrides));
+ inOutSliceSizes,
+ inOutStrides));
}
}
// Extract rhs slice of size {c} @ [kw].
@@ -2893,7 +2894,8 @@ struct Conv1DGenerator
for (int64_t w = 0; w < wSize; w += wSizeStep) {
resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
loc, res,
- /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
+ /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+ inOutSliceSizes,
inOutStrides));
}
@@ -2901,8 +2903,7 @@ struct Conv1DGenerator
return kw * (wSize / wSizeStep) + w;
};
- auto inOutFlattenSliceSizes =
- SmallVector<int64_t>{nSize, wSizeStep * cSize};
+ auto inOutFlattenSliceSizes = SmallVector<int64_t>{nSize, wSizeStep * cSize};
auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
auto resCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
@@ -3085,7 +3086,7 @@ struct Conv1DGenerator
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
/*rhsIndex*/ {kw, c},
/*resIndex*/ {n, w, c}}))
- return depthwiseConvGeneric(flatten);
+ return depthwiseConv(flatten);
return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
}
>From 8e8f56d96e1b674c78079365db635c9b6418f2f6 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 6 Dec 2023 18:40:49 +0000
Subject: [PATCH 4/4] fixup! [mlir][linalg][conv] Flatten the channel dimension
when vectorizing
Final tweaks (more comments, revert unrelated change in a test file)
---
.../TransformOps/LinalgTransformOps.cpp | 3 +
.../Linalg/Transforms/Vectorization.cpp | 24 +-
.../Linalg/vectorize-convolution-flatten.mlir | 401 +++++++++++-------
.../Dialect/Linalg/vectorize-convolution.mlir | 12 +-
4 files changed, 267 insertions(+), 173 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 279a7c487b7d5..86fa4f8cccc4b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2990,6 +2990,9 @@ struct VectorizationPattern : public RewritePattern {
/// Controls whether to vectorize `tensor.extract` when the input tensor is
/// rank >= 2.
bool vectorizeNDExtract = false;
+ /// Controls whether to "flatten" the channel dimension when vectorizing 1D
+ /// depthwise convolutions. This should lead to better vectorization for
+ /// tensors with a small channel dimension (i.e. only a few channels).
bool flatten1DDepthwiseConv = false;
};
} // namespace
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b618a1b1daeea..c21d007c931b9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2881,8 +2881,7 @@ struct Conv1DGenerator
lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
loc, lhs,
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
- inOutSliceSizes,
- inOutStrides));
+ inOutSliceSizes, inOutStrides));
}
}
// Extract rhs slice of size {c} @ [kw].
@@ -2894,8 +2893,7 @@ struct Conv1DGenerator
for (int64_t w = 0; w < wSize; w += wSizeStep) {
resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
loc, res,
- /*offsets=*/ArrayRef<int64_t>{0, w, 0},
- inOutSliceSizes,
+ /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
inOutStrides));
}
@@ -2903,9 +2901,10 @@ struct Conv1DGenerator
return kw * (wSize / wSizeStep) + w;
};
- auto inOutFlattenSliceSizes = SmallVector<int64_t>{nSize, wSizeStep * cSize};
+ auto inOutFlattenSliceSizes =
+ SmallVector<int64_t>{nSize, wSizeStep * cSize};
auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
- auto resCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+ auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
@@ -2913,6 +2912,8 @@ struct Conv1DGenerator
Value resVal = resVals[w];
ShapedType filterBCastTy = cast<ShapedType>(resVal.getType());
if (flatten) {
+ // Flatten the input and output vectors (collapse the channel
+ // dimension); the filter is flattened in depthwiseConv1dSliceAsMulAcc.
lhsVal = rewriter.create<vector::ShapeCastOp>(
loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
@@ -2920,9 +2921,11 @@ struct Conv1DGenerator
}
resVals[w] = depthwiseConv1dSliceAsMulAcc(
rewriter, loc, lhsVal, rhsVals[kw], resVal, filterBCastTy, flatten);
- if (flatten)
+ if (flatten) {
+ // Un-flatten the output vector (restore the channel dimension)
resVals[w] = rewriter.create<vector::ShapeCastOp>(
loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
+ }
}
}
@@ -2970,8 +2973,11 @@ struct Conv1DGenerator
rhs = rewriter.create<vector::BroadcastOp>(
loc, bcastTy.clone(rhsTy.getElementType()), rhs);
- if (flatten)
- rhs = rewriter.create<vector::ShapeCastOp>(loc, resTy, rhs);
+ if (flatten) {
+ // Flatten the channel dimension
+ rhs = rewriter.create<vector::ShapeCastOp>(
+ loc, resTy.clone(rhsTy.getElementType()), rhs);
+ }
rhs = promote(rewriter, loc, rhs, resTy);
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
index 17f93ed46e779..a242d09671825 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
@@ -1,6 +1,8 @@
// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
-func.func @flatten_tensor(%input: tensor<1x8x3xi8>, %filter: tensor<1x3xi8>, %output: tensor<1x8x3xi8>) -> (tensor<1x8x3xi8>) {
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x3xi8>,
+ %filter: tensor<1x3xi8>,
+ %output: tensor<1x8x3xi8>) -> (tensor<1x8x3xi8>) {
%res = linalg.depthwise_conv_1d_nwc_wc
{dilations = dense<1> : vector<1xi64>,
strides = dense<1> : vector<1xi64>}
@@ -17,78 +19,81 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK-LABEL: func.func @flatten_tensor(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x3xi8>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3xi8>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> {
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8
-// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : tensor<1x8x3xi8>, vector<1x8x3xi8>
-// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x3xi8>, vector<1x3xi8>
-// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : tensor<1x8x3xi8>, vector<1x8x3xi8>
-// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_6]][0] : vector<3xi8> from vector<1x3xi8>
-// CHECK: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_5]] : vector<1x8x3xi8> to vector<1x24xi8>
-// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<1x8x3xi8> to vector<1x24xi8>
-// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_8]] : vector<3xi8> to vector<1x8x3xi8>
-// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<1x8x3xi8> to vector<1x24xi8>
-// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_9]], %[[VAL_12]] : vector<1x24xi8>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : vector<1x24xi8>
-// CHECK: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<1x24xi8> to vector<1x8x3xi8>
-// CHECK: %[[VAL_16:.*]] = vector.transfer_write %[[VAL_15]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<1x8x3xi8>, tensor<1x8x3xi8>
-// CHECK: return %[[VAL_16]] : tensor<1x8x3xi8>
-// CHECK: }
+// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor
+// CHECK-SAME: %[[INPUT:.*]]: tensor<1x8x3xi8>,
+// CHECK-SAME: %[[FILTER:.*]]: tensor<1x3xi8>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> {
+
+// CHECK-DAG: %[[C0_IDX:.*]] = arith.constant 0 : index
+
+/// Read the whole data in one shot.
+// CHECK: %[[V_INPUT_R:.*]] = vector.transfer_read %[[INPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
+// CHECK: %[[V_FILTER_R:.*]] = vector.transfer_read %[[FILTER]][%[[C0_IDX]], %[[C0_IDX]]]
+// CHECK: %[[V_OUTPUT_R:.*]] = vector.transfer_read %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
+
+// CHECK: %[[V_FILTER_0:.*]] = vector.extract %[[V_FILTER_R]][0] : vector<3xi8> from vector<1x3xi8>
+
+/// w == 0, kw = 0
+// CHECK: %[[SC_INPUT:.*]] = vector.shape_cast %[[V_INPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK: %[[SC_OUTPUT:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK: %[[B_FILTER:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<3xi8> to vector<1x8x3xi8>
+// CHECK: %[[SC_FILTER:.*]] = vector.shape_cast %[[B_FILTER]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK: %[[MULI:.*]] = arith.muli %[[SC_INPUT]], %[[SC_FILTER]] : vector<1x24xi8>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[SC_OUTPUT]] : vector<1x24xi8>
+
+// Write the result back in one shot.
+// CHECK: %[[SC_ADDI:.*]] = vector.shape_cast %[[ADDI]] : vector<1x24xi8> to vector<1x8x3xi8>
+// CHECK: vector.transfer_write %[[SC_ADDI]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
//------
-func.func @flatten_memref(%input: memref<1x8x3xi8>, %filter: memref<1x3xi8>, %output: memref<1x8x3xi8>) {
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x4xf32>,
+ %filter: memref<2x4xf32>,
+ %output: memref<3x2x4xf32>) {
linalg.depthwise_conv_1d_nwc_wc
- {dilations = dense<1> : vector<1xi64>,
- strides = dense<1> : vector<1xi64>}
- ins(%input, %filter : memref<1x8x3xi8>, memref<1x3xi8>)
- outs(%output : memref<1x8x3xi8>)
+ {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
+ outs(%output : memref<3x2x4xf32>)
return
}
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
+// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2
+// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
-// CHECK-LABEL: func.func @flatten_memref(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<1x8x3xi8>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<1x3xi8>,
-// CHECK-SAME: %[[VAL_2:.*]]: memref<1x8x3xi8>) {
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8
-// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
-// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x3xi8>, vector<1x3xi8>
-// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
-// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_6]][0] : vector<3xi8> from vector<1x3xi8>
-// CHECK: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_5]] : vector<1x8x3xi8> to vector<1x24xi8>
-// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<1x8x3xi8> to vector<1x24xi8>
-// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_8]] : vector<3xi8> to vector<1x8x3xi8>
-// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<1x8x3xi8> to vector<1x24xi8>
-// CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_9]], %[[VAL_12]] : vector<1x24xi8>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : vector<1x24xi8>
-// CHECK: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<1x24xi8> to vector<1x8x3xi8>
-// CHECK: vector.transfer_write %[[VAL_15]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<1x8x3xi8>, memref<1x8x3xi8>
-// CHECK: return
-// CHECK: }
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
-// -----
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
+
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xf32> from vector<2x4xf32>
+
+
+/// w == 0, kw = 0
+// CHECK: %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK: %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
+// CHECK: %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK: %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[SC_B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>
+
+/// w == 0, kw = 1
+// CHECK: %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK: %[[B_V_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xf32> to vector<3x2x4xf32>
+// CHECK: %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_V_FILTER_1]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK: %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[SC_B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>
+
+// Write the result back in one shot.
+// CHECK: %[[SC_FMA_1:.*]] = vector.shape_cast %[[FMA_1]] : vector<3x8xf32> to vector<3x2x4xf32>
+// CHECK: vector.transfer_write %[[SC_FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
-func.func @flatten_memref_wider_filter(%input: memref<1x8x3xi8>, %filter: memref<2x3xi8>, %output: memref<1x7x3xi8>) {
- linalg.depthwise_conv_1d_nwc_wc
- {dilations = dense<1> : vector<1xi64>,
- strides = dense<1> : vector<1xi64>}
- ins(%input, %filter : memref<1x8x3xi8>, memref<2x3xi8>)
- outs(%output : memref<1x7x3xi8>)
- return
-}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -99,46 +104,59 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK-LABEL: func.func @flatten_memref_wider_filter(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<1x8x3xi8>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<2x3xi8>,
-// CHECK-SAME: %[[VAL_2:.*]]: memref<1x7x3xi8>) {
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8
-// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
-// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x3xi8>, vector<2x3xi8>
-// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<1x7x3xi8>, vector<1x7x3xi8>
-// CHECK: %[[VAL_8:.*]] = vector.extract_strided_slice %[[VAL_5]] {offsets = [0, 0, 0], sizes = [1, 7, 3], strides = [1, 1, 1]} : vector<1x8x3xi8> to vector<1x7x3xi8>
-// CHECK: %[[VAL_9:.*]] = vector.extract_strided_slice %[[VAL_5]] {offsets = [0, 1, 0], sizes = [1, 7, 3], strides = [1, 1, 1]} : vector<1x8x3xi8> to vector<1x7x3xi8>
-// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<3xi8> from vector<2x3xi8>
-// CHECK: %[[VAL_11:.*]] = vector.extract %[[VAL_6]][1] : vector<3xi8> from vector<2x3xi8>
-// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_8]] : vector<1x7x3xi8> to vector<1x21xi8>
-// CHECK: %[[VAL_13:.*]] = vector.shape_cast %[[VAL_7]] : vector<1x7x3xi8> to vector<1x21xi8>
-// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_10]] : vector<3xi8> to vector<1x7x3xi8>
-// CHECK: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<1x7x3xi8> to vector<1x21xi8>
-// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_12]], %[[VAL_15]] : vector<1x21xi8>
-// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_13]] : vector<1x21xi8>
-// CHECK: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_9]] : vector<1x7x3xi8> to vector<1x21xi8>
-// CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_11]] : vector<3xi8> to vector<1x7x3xi8>
-// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<1x7x3xi8> to vector<1x21xi8>
-// CHECK: %[[VAL_21:.*]] = arith.muli %[[VAL_18]], %[[VAL_20]] : vector<1x21xi8>
-// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_17]] : vector<1x21xi8>
-// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[VAL_22]] : vector<1x21xi8> to vector<1x7x3xi8>
-// CHECK: vector.transfer_write %[[VAL_23]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<1x7x3xi8>, memref<1x7x3xi8>
-// CHECK: return
-// CHECK: }
-
-
// -----
-func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2(%input: memref<3x5x4xi8>,
+ %filter: memref<2x4xi8>,
+ %output: memref<3x2x4xi32>) {
linalg.depthwise_conv_1d_nwc_wc
{dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
- ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
- outs(%output : memref<3x2x4xf32>)
+ ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
+ outs(%output : memref<3x2x4xi32>)
return
}
+// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2
+// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xi8>, %[[FILTER:[0-9a-z]+]]: memref<2x4xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xi32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xi8> from vector<2x4xi8>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xi8> from vector<2x4xi8>
+
+/// w == 0, kw = 0
+// CHECK: %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x8xi8>
+// CHECK: %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xi32> to vector<3x8xi32>
+// CHECK: %[[EXT_INPUT_0:.*]] = arith.extsi %[[SC_V_INPUT_0]] : vector<3x8xi8> to vector<3x8xi32>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8>
+// CHECK: %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x8xi8>
+// CHECK: %[[EXT_FILTER_0:.*]] = arith.extsi %[[SC_B_FILTER_0]] : vector<3x8xi8> to vector<3x8xi32>
+// CHECK: %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x8xi32>
+// CHECK: %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xi32>
+
+/// w == 0, kw = 1
+// CHECK: %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x8xi8>
+// CHECK: %[[EXT_INPUT_1:.*]] = arith.extsi %[[SC_V_INPUT_1]] : vector<3x8xi8> to vector<3x8xi32>
+// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8>
+// CHECK: %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x8xi8>
+// CHECK: %[[EXT_FILTER_1:.*]] = arith.extsi %[[SC_B_FILTER_1]] : vector<3x8xi8> to vector<3x8xi32>
+// CHECK: %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x8xi32>
+// CHECK: %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x8xi32>
+
+// Write the result back in one shot.
+// CHECK: %[[SC_ADD_1:.*]] = vector.shape_cast %[[ADD_1]] : vector<3x8xi32> to vector<3x2x4xi32>
+// CHECK: vector.transfer_write %[[SC_ADD_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
@@ -148,77 +166,144 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5x4xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<2x4xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: memref<3x2x4xf32>) {
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<3x5x4xf32>, vector<3x4x4xf32>
-// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
-// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<3x2x4xf32>, vector<3x2x4xf32>
-// CHECK: %[[VAL_8:.*]] = vector.extract_strided_slice %[[VAL_5]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
-// CHECK: %[[VAL_9:.*]] = vector.extract_strided_slice %[[VAL_5]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
-// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<4xf32> from vector<2x4xf32>
-// CHECK: %[[VAL_11:.*]] = vector.extract %[[VAL_6]][1] : vector<4xf32> from vector<2x4xf32>
-// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_8]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[VAL_13:.*]] = vector.shape_cast %[[VAL_7]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_10]] : vector<4xf32> to vector<3x2x4xf32>
-// CHECK: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[VAL_16:.*]] = vector.fma %[[VAL_12]], %[[VAL_15]], %[[VAL_13]] : vector<3x8xf32>
-// CHECK: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_9]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_11]] : vector<4xf32> to vector<3x2x4xf32>
-// CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_18]] : vector<3x2x4xf32> to vector<3x8xf32>
-// CHECK: %[[VAL_20:.*]] = vector.fma %[[VAL_17]], %[[VAL_19]], %[[VAL_16]] : vector<3x8xf32>
-// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_20]] : vector<3x8xf32> to vector<3x2x4xf32>
-// CHECK: vector.transfer_write %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<3x2x4xf32>, memref<3x2x4xf32>
-// CHECK: return
-// CHECK: }
-
// -----
-func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %filter: memref<2x4xi8>, %output: memref<3x2x4xi32>) {
- linalg.depthwise_conv_1d_nwc_wc
- {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
- ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
- outs(%output : memref<3x2x4xi32>)
- return
+/// 1-D depthwise conv with strides = 2 and dilations = 1. The filter is 3 wide
+/// (tensor<3x4xi8>) and the output is 3 wide (tensor<3x3x4xi8>), so the
+/// vectorized input read spans (3 - 1) * 2 + (3 - 1) * 1 + 1 = 7 columns.
+/// With stride 2 the expected IR handles one output column at a time: the input
+/// slice for output column w and filter tap kw starts at column w * 2 + kw, and
+/// each 3x1x4 slice is reshaped to 3x4 before the multiply-accumulate.
+func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2(%input: tensor<3x9x4xi8>,
+ %filter: tensor<3x4xi8>,
+ %output: tensor<3x3x4xi8>) -> tensor<3x3x4xi8> {
+ %res = linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<1> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+ ins(%input, %filter : tensor<3x9x4xi8>, tensor<3x4xi8>)
+ outs(%output : tensor<3x3x4xi8>) -> tensor<3x3x4xi8>
+ return %res : tensor<3x3x4xi8>
}
+// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2
+// CHECK-SAME: %[[INPUT:.*]]: tensor<3x9x4xi8>,
+// CHECK-SAME: %[[FILTER:.*]]: tensor<3x4xi8>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<3x3x4xi8>) -> tensor<3x3x4xi8> {
+
+// CHECK-DAG: %[[C0_IDX:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
+
+/// Read the whole data in one shot.
+// CHECK: %[[V_INPUT_R:.*]] = vector.transfer_read %[[INPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]], %[[C0_I8]]
+// CHECK: %[[V_FILTER_R:.*]] = vector.transfer_read %[[FILTER]][%[[C0_IDX]], %[[C0_IDX]]], %[[C0_I8]]
+// CHECK: %[[V_OUTPUT_R:.*]] = vector.transfer_read %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]], %[[C0_I8]]
+
+// CHECK: %[[V_INPUT_0:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[V_INPUT_1:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[V_INPUT_2:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 4, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[V_INPUT_3:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 1, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[V_INPUT_4:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 3, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[V_INPUT_5:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 5, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[V_INPUT_6:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[V_INPUT_7:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 4, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[V_INPUT_8:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 6, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+
+// CHECK: %[[V_FILTER_0:.*]] = vector.extract %[[V_FILTER_R]][0] : vector<4xi8> from vector<3x4xi8>
+// CHECK: %[[V_FILTER_1:.*]] = vector.extract %[[V_FILTER_R]][1] : vector<4xi8> from vector<3x4xi8>
+// CHECK: %[[V_FILTER_2:.*]] = vector.extract %[[V_FILTER_R]][2] : vector<4xi8> from vector<3x4xi8>
+
+// CHECK: %[[V_OUTPUT_0:.*]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x3x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[V_OUTPUT_1:.*]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 1, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x3x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[V_OUTPUT_2:.*]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x3x4xi8> to vector<3x1x4xi8>
+
+/// w == 0, kw == 0
+// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_24:.*]] = vector.shape_cast %[[V_OUTPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK: %[[VAL_26:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_23]], %[[VAL_26]] : vector<3x4xi8>
+// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_27]], %[[VAL_24]] : vector<3x4xi8>
+
+/// w == 1, kw == 0
+// CHECK: %[[VAL_29:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[V_OUTPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_33:.*]] = arith.muli %[[VAL_29]], %[[VAL_32]] : vector<3x4xi8>
+// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_30]] : vector<3x4xi8>
+
+/// w == 2, kw == 0
+// CHECK: %[[VAL_35:.*]] = vector.shape_cast %[[V_INPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_36:.*]] = vector.shape_cast %[[V_OUTPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK: %[[VAL_38:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_35]], %[[VAL_38]] : vector<3x4xi8>
+// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_36]] : vector<3x4xi8>
+
+/// w == 0, kw == 1
+// CHECK: %[[VAL_41:.*]] = vector.shape_cast %[[V_INPUT_3]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_44:.*]] = arith.muli %[[VAL_41]], %[[VAL_43]] : vector<3x4xi8>
+// CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_44]], %[[VAL_28]] : vector<3x4xi8>
+
+/// w == 1, kw == 1
+// CHECK: %[[VAL_46:.*]] = vector.shape_cast %[[V_INPUT_4]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK: %[[VAL_48:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_49:.*]] = arith.muli %[[VAL_46]], %[[VAL_48]] : vector<3x4xi8>
+// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_49]], %[[VAL_34]] : vector<3x4xi8>
+
+/// w == 2, kw == 1
+// CHECK: %[[VAL_51:.*]] = vector.shape_cast %[[V_INPUT_5]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK: %[[VAL_53:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_54:.*]] = arith.muli %[[VAL_51]], %[[VAL_53]] : vector<3x4xi8>
+// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_54]], %[[VAL_40]] : vector<3x4xi8>
+
+/// w == 0, kw == 2
+// CHECK: %[[VAL_56:.*]] = vector.shape_cast %[[V_INPUT_6]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK: %[[VAL_58:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_59:.*]] = arith.muli %[[VAL_56]], %[[VAL_58]] : vector<3x4xi8>
+// CHECK: %[[VAL_60:.*]] = arith.addi %[[VAL_59]], %[[VAL_45]] : vector<3x4xi8>
+
+/// w == 1, kw == 2 (preceded by reshaping the finished w == 0 result)
+// CHECK: %[[VAL_61:.*]] = vector.shape_cast %[[VAL_60]] : vector<3x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[VAL_62:.*]] = vector.shape_cast %[[V_INPUT_7]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK: %[[VAL_64:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_65:.*]] = arith.muli %[[VAL_62]], %[[VAL_64]] : vector<3x4xi8>
+// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_65]], %[[VAL_50]] : vector<3x4xi8>
+
+/// w == 2, kw == 2 (preceded by reshaping the finished w == 1 result)
+// CHECK: %[[VAL_67:.*]] = vector.shape_cast %[[VAL_66]] : vector<3x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[VAL_68:.*]] = vector.shape_cast %[[V_INPUT_8]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK: %[[VAL_70:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK: %[[VAL_71:.*]] = arith.muli %[[VAL_68]], %[[VAL_70]] : vector<3x4xi8>
+// CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_71]], %[[VAL_55]] : vector<3x4xi8>
+
+/// Reshape the finished w == 2 result and write all results back.
+// CHECK: %[[VAL_73:.*]] = vector.shape_cast %[[VAL_72]] : vector<3x4xi8> to vector<3x1x4xi8>
+// CHECK: %[[VAL_74:.*]] = vector.insert_strided_slice %[[VAL_61]], %[[V_OUTPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x1x4xi8> into vector<3x3x4xi8>
+// CHECK: %[[VAL_75:.*]] = vector.insert_strided_slice %[[VAL_67]], %[[VAL_74]]
+// CHECK-SAME: {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<3x1x4xi8> into vector<3x3x4xi8>
+// CHECK: %[[VAL_76:.*]] = vector.insert_strided_slice %[[VAL_73]], %[[VAL_75]]
+// CHECK-SAME: {offsets = [0, 2, 0], strides = [1, 1, 1]} : vector<3x1x4xi8> into vector<3x3x4xi8>
+// CHECK: %[[VAL_77:.*]] = vector.transfer_write %[[VAL_76]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ // Opt in to the lowering that collapses the channel dim into the width dim.
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
-// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5x4xi8>,
-// CHECK-SAME: %[[VAL_1:.*]]: memref<2x4xi8>,
-// CHECK-SAME: %[[VAL_2:.*]]: memref<3x2x4xi32>) {
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i8
-// CHECK: %[[VAL_5:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_6:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<3x5x4xi8>, vector<3x4x4xi8>
-// CHECK: %[[VAL_7:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xi8>, vector<2x4xi8>
-// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_5]] {in_bounds = [true, true, true]} : memref<3x2x4xi32>, vector<3x2x4xi32>
-// CHECK: %[[VAL_9:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
-// CHECK: %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
-// CHECK: %[[VAL_11:.*]] = vector.extract %[[VAL_7]][0] : vector<4xi8> from vector<2x4xi8>
-// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_7]][1] : vector<4xi8> from vector<2x4xi8>
-// CHECK: %[[VAL_13:.*]] = arith.extsi %[[VAL_9]] : vector<3x2x4xi8> to vector<3x2x4xi32>
-// CHECK: %[[VAL_14:.*]] = arith.extsi %[[VAL_11]] : vector<4xi8> to vector<4xi32>
-// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_14]] : vector<4xi32> to vector<3x2x4xi32>
-// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_13]], %[[VAL_15]] : vector<3x2x4xi32>
-// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_8]] : vector<3x2x4xi32>
-// CHECK: %[[VAL_18:.*]] = arith.extsi %[[VAL_10]] : vector<3x2x4xi8> to vector<3x2x4xi32>
-// CHECK: %[[VAL_19:.*]] = arith.extsi %[[VAL_12]] : vector<4xi8> to vector<4xi32>
-// CHECK: %[[VAL_20:.*]] = vector.broadcast %[[VAL_19]] : vector<4xi32> to vector<3x2x4xi32>
-// CHECK: %[[VAL_21:.*]] = arith.muli %[[VAL_18]], %[[VAL_20]] : vector<3x2x4xi32>
-// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_17]] : vector<3x2x4xi32>
-// CHECK: vector.transfer_write %[[VAL_22]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<3x2x4xi32>, memref<3x2x4xi32>
-// CHECK: return
-// CHECK: }
-
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 59b242d530442..93e36a69567bd 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -535,9 +535,9 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
/// Read the whole data in one shot.
-// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
-// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
-// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
@@ -575,9 +575,9 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %fi
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
/// Read the whole data in one shot.
-// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
-// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
-// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
More information about the Mlir-commits
mailing list