[Mlir-commits] [mlir] [mlir][linalg][conv] Flatten the channel dimension when vectorizing (PR #71918)
llvmlistbot at llvm.org
Fri Nov 10 02:33:24 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg
Author: Andrzej Warzyński (banach-space)
Changes
The current vectorization of 1D depthwise convolutions in Linalg is
_sub-optimal_ for tensors with a small number of channels, e.g.:
```mlir
linalg.depthwise_conv_1d_nwc_wc
{dilations = dense<1> : vector<1xi64>,
strides = dense<1> : vector<1xi64>}
ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
```
That's because, ultimately (i.e. at the LLVM level), vectorization happens
along the trailing dimension (i.e. the channel dimension). In this case it
leads to vectors with only 3 elements (or worse, a single element when
there's just 1 channel). For comparison, a 128-bit wide vector register can
hold 16 x i8, so a 3 x i8 vector leaves most of the register unused.
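To illustrate, with the current lowering the multiply-accumulate for the
example above operates on vectors whose innermost dimension is the
3-element channel dimension (a hand-written sketch with made-up SSA names,
not the exact compiler output):
```mlir
// Default vectorization (sketch): the innermost vector dimension is the
// 3-element channel dimension, so very little of a 128-bit register is used.
%mul = arith.muli %lhs, %rhs : vector<1x8x3xi8>
%acc = arith.addi %mul, %res : vector<1x8x3xi8>
```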
Instead, this patch adds an option to flatten/collapse the channel
dimension into the width dimension of the input/output:
```mlir
%collapsed = tensor.collapse_shape %input [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
%collapsed_0 = tensor.collapse_shape %output [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
```
(Note that for this to work, the filter is broadcast rather than
reshaped; see the test cases, and the sketch below, for details.)
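For the `tensor<1x8x3xi8>` example above, that means the 3-element filter
slice is tiled to cover the flattened `w * c` dimension, roughly as follows
(a hand-written sketch with made-up SSA names, not the exact output checked
in the new tests):
```mlir
// Repeat the 3-element filter slice 8 times to cover the flattened
// w * c = 24 elements, then broadcast to add the leading batch dimension.
%tiled = vector.shuffle %f, %f [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2,
                                0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    : vector<3xi8>, vector<3xi8>
%bcast = vector.broadcast %tiled : vector<24xi8> to vector<1x24xi8>
```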
The new vectorization strategy is implemented in `depthwiseConvFlatten`,
which is based on `depthwiseConvGeneric` (i.e. the original vectorization
hook, renamed in this patch). The current vectorization is preserved and
remains the default; the new one can be selected through e.g. a transform
dialect attribute:
```mlir
transform.structured.vectorize_children_and_apply_patterns %conv {flatten_1d_depthwise_conv}
```
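For the `tensor<1x8x3xi8>` example, the flattened vectorization then
produces IR of roughly the following shape (a simplified sketch, not the
CHECK lines from the new test file; only the single `kw` iteration is
shown):
```mlir
%c0  = arith.constant 0 : index
%pad = arith.constant 0 : i8
%in  = tensor.collapse_shape %input  [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
%out = tensor.collapse_shape %output [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
%lhs = vector.transfer_read %in[%c0, %c0], %pad : tensor<1x24xi8>, vector<1x24xi8>
%acc = vector.transfer_read %out[%c0, %c0], %pad : tensor<1x24xi8>, vector<1x24xi8>
// Multiply-accumulate with the filter broadcast to vector<1x24xi8>
// (see the earlier sketch).
%mul = arith.muli %lhs, %bcast : vector<1x24xi8>
%sum = arith.addi %mul, %acc : vector<1x24xi8>
%w   = vector.transfer_write %sum, %out[%c0, %c0] : vector<1x24xi8>, tensor<1x24xi8>
%res = tensor.expand_shape %w [[0], [1, 2]] : tensor<1x24xi8> into tensor<1x8x3xi8>
```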
A forthcoming patch will implement a strategy to automatically switch
between the two implementations, depending on the shape of the input
tensors.
Co-authored by: Bradley Smith <bradley.smith@arm.com>
---
Patch is 38.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71918.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+3-1)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+2-1)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+16-5)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+217-11)
- (added) mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir (+222)
- (modified) mlir/test/Dialect/Linalg/vectorize-convolution.mlir (+6-6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f1c3d717f1fa951..310efe164f93950 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2034,6 +2034,7 @@ def VectorizeChildrenAndApplyPatternsOp :
let arguments = (ins TransformHandleTypeInterface:$target,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
+ UnitAttr:$flatten_1d_depthwise_conv,
UnitAttr:$disable_multi_reduction_to_contract_patterns,
UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
let results = (outs TransformHandleTypeInterface:$transformed);
@@ -2045,7 +2046,8 @@ def VectorizeChildrenAndApplyPatternsOp :
let builders = [
OpBuilder<(ins "Value":$target,
CArg<"bool", "false">:$vectorizePadding,
- CArg<"bool", "false">:$vectorizeNDExtract)>,
+ CArg<"bool", "false">:$vectorizeNDExtract,
+ CArg<"bool", "false">:$flatten1DDepthwise)>
];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6547648f7495c31..a4aee1f45249c2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -753,7 +753,8 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
- bool vectorizeNDExtract = false);
+ bool vectorizeNDExtract = false,
+ bool flatten1DDepthwiseConv = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..35e8be7806928e1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2937,7 +2937,7 @@ LogicalResult TileUsingForallOp::verify() {
void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
- bool vectorizePadding, bool vectorizeExtract) {
+ bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
result.addOperands(target);
if (vectorizePadding) {
result.addAttribute(
@@ -2951,6 +2951,12 @@ void transform::VectorizeChildrenAndApplyPatternsOp::build(
result.name),
builder.getUnitAttr());
}
+ if (flatten1DDepthwiseConv) {
+ result.addAttribute(
+ VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
+ result.name),
+ builder.getUnitAttr());
+ }
result.addTypes(transform::AnyOpType::get(builder.getContext()));
}
@@ -2959,22 +2965,26 @@ namespace {
/// VectorizeChildrenAndApplyPatternsOp::applyToOne.
struct VectorizationPattern : public RewritePattern {
explicit VectorizationPattern(MLIRContext *context,
- bool vectorizeExtract = false)
+ bool vectorizeExtract = false,
+ bool flattenConv = false)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
- vectorizeNDExtract(vectorizeExtract) {}
+ vectorizeNDExtract(vectorizeExtract),
+ flatten1DDepthwiseConv(flattenConv) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return rewriter.notifyMatchFailure(op, "expected Linalg Op");
return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
- /*scalableVecDims=*/{}, vectorizeNDExtract);
+ /*scalableVecDims=*/{}, vectorizeNDExtract,
+ flatten1DDepthwiseConv);
}
private:
/// Controls whether to vectorize `tensor.extract` when the input tensor is
/// rank >= 2.
bool vectorizeNDExtract = false;
+ bool flatten1DDepthwiseConv = false;
};
} // namespace
@@ -2991,7 +3001,8 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
MLIRContext *ctx = getContext();
RewritePatternSet patterns(ctx);
- patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());
+ patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
+ getFlatten_1dDepthwiseConv());
if (!getDisableTransferPermutationMapLoweringPatterns())
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b8d82159856825f..f6f74b448edf9a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -44,8 +44,9 @@ using namespace mlir::linalg;
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
/// Try to vectorize `convOp` as a convolution.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
- LinalgOp convOp);
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+ bool flatten1DDepthwiseConv = false);
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
@@ -1664,7 +1665,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
- bool vectorizeNDExtract) {
+ bool vectorizeNDExtract,
+ bool flatten1DDepthwiseConv) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -1696,8 +1698,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. Will require stride/dilation attributes inference.
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
- FailureOr<Operation *> convOr =
- vectorizeConvolution(rewriter, linalgOp);
+ FailureOr<Operation *> convOr = vectorizeConvolution(
+ rewriter, linalgOp, flatten1DDepthwiseConv);
if (succeeded(convOr)) {
llvm::append_range(results, (*convOr)->getResults());
return success();
@@ -2822,7 +2824,7 @@ struct Conv1DGenerator
/// kw is always unrolled.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
/// > 1.
- FailureOr<Operation *> depthwiseConv() {
+ FailureOr<Operation *> depthwiseConvGeneric() {
if (!valid)
return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
@@ -2936,6 +2938,176 @@ struct Conv1DGenerator
.getOperation();
}
+ /// Generate a vector implementation for ("flatten channel dim"):
+ /// ```
+ /// Op def: ( n, w, c, kw)
+ /// Iters: ({Par(), Par(), Par(), Red()})
+ /// Layout: {{n, 1 * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+ /// ```
+ /// c of the input/output is collapsed with w. kw is always unrolled and
+ /// broadcast to match w.
+ ///
+ /// TODO: Add support for non-unit stride/dilation
+ /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
+ /// > 1.
+ FailureOr<Operation *> depthwiseConvFlatten() {
+ if (!valid)
+ return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
+
+ int64_t nSize, iSize, wSize, cSize, kwSize;
+ // kernel{kw, c}
+ bindShapeDims(rhsShapedType, kwSize, cSize);
+ // out{n, w, c}
+ bindShapeDims(resShapedType, nSize, wSize);
+ // in{n, w, c}
+ bindShapeDims(lhsShapedType, nSize, iSize);
+
+ vector::TransferWriteOp write;
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+ if (strideW != 1)
+ return rewriter.notifyMatchFailure(
+ op, "Non-unit strides are not supported yet");
+ if (dilationW != 1)
+ return rewriter.notifyMatchFailure(
+ op, "Non-unit dilations are not supported yet");
+
+ Type lhsEltType = lhsShapedType.getElementType();
+ Type rhsEltType = rhsShapedType.getElementType();
+ Type resEltType = resShapedType.getElementType();
+ VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+ VectorType lhsType = VectorType::get(
+ {nSize,
+ // iw = (ow * sw + kw * dw - 1) * c
+ // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+ (((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1) *
+ cSize},
+ lhsEltType);
+
+ VectorType resType = VectorType::get({nSize, wSize * cSize}, resEltType);
+
+ Value res, lhs, lhsFlat, resFlat;
+ // Read rhs slice of size {kw, c} @ [0, 0].
+ Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+ ValueRange{zero, zero});
+
+ SmallVector<ReassociationIndices> reassociation = {{0}, {1, 2}};
+
+ // Flatten w and c dimensions
+ auto lhsTypeCollapsed = VectorType::get({nSize, iSize * cSize}, lhsEltType);
+ auto linalgOp = dyn_cast<LinalgOp>(op);
+ lhsFlat =
+ linalgOp.hasTensorSemantics()
+ ? (Value)rewriter.create<tensor::CollapseShapeOp>(
+ loc,
+ RankedTensorType::get(lhsTypeCollapsed.getShape(),
+ lhsEltType),
+ lhsShaped, reassociation)
+ : (Value)rewriter.create<memref::CollapseShapeOp>(
+ loc, MemRefType::get(lhsTypeCollapsed.getShape(), lhsEltType),
+ lhsShaped, reassociation);
+ resFlat =
+ linalgOp.hasTensorSemantics()
+ ? (Value)rewriter.create<tensor::CollapseShapeOp>(
+ loc, RankedTensorType::get(resType.getShape(), resEltType),
+ resShaped, reassociation)
+ : (Value)rewriter.create<memref::CollapseShapeOp>(
+ loc, MemRefType::get(resType.getShape(), resEltType),
+ resShaped, reassociation);
+
+ // Read lhs slice of size {n, (w * wSize + kw * dilationW) * c} @ [0,
+ // 0].
+ lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsFlat,
+ ValueRange{zero, zero});
+ // Read res slice of size {n, w * c} @ [0, 0].
+ res = rewriter.create<vector::TransferReadOp>(loc, resType, resFlat,
+ ValueRange{zero, zero});
+
+ //===------------------------------------------------------------------===//
+ // Begin vector-only rewrite part
+ //===------------------------------------------------------------------===//
+ // Unroll along kw and read slices of lhs and rhs.
+ SmallVector<Value> lhsVals, rhsVals, resVals;
+ // Extract lhs slice of size {n, wSizeStep * c}
+ // @ [0, (sw * w + dw * kw) * cSize].
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
+ for (int64_t w = 0; w < wSize; w += wSize) {
+ lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, lhs,
+ /*offsets=*/
+ ArrayRef<int64_t>{0, (w * wSize + kw * dilationW) * cSize},
+ /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
+ /*strides=*/ArrayRef<int64_t>{1, 1}));
+ }
+ }
+ // Extract rhs slice of size {c} @ [kw].
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
+ rhsVals.push_back(rewriter.create<vector::ExtractOp>(
+ loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+ }
+
+ // Extract res slice
+ // Flattened case: {n, wSizeStep * c} @ [0, w].
+ for (int64_t w = 0; w < wSize; w += wSize) {
+ resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, res,
+ /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+ /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
+ /*strides=*/ArrayRef<int64_t>{1, 1}));
+ }
+
+ auto linearIndex = [&](int64_t kw, int64_t w) {
+ return kw * (wSize / wSize) + w;
+ };
+
+ // Compute contraction:
+ // O{n, w * c} += I{n, (sw * w + dw * kw) * c} * F{c}
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
+ for (int64_t w = 0; w < wSize; w += wSize) {
+ resVals[w] = depthwiseConv1dFlatSliceAsMulAcc(
+ rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
+ resVals[w]);
+ }
+ }
+
+ // It's possible we failed to create the FMA.
+ if (!llvm::all_of(resVals, [](Value v) { return v; })) {
+ // Manually revert (in reverse order) to avoid leaving a bad IR state.
+ for (auto &collection :
+ {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
+ for (Value v : collection)
+ rewriter.eraseOp(v.getDefiningOp());
+ return rewriter.notifyMatchFailure(op, "failed to create FMA");
+ }
+
+ // Write back res slice. This does not depend on kw.
+ // Flattened case: {n, wSizeStep * c} @ [0, w].
+ for (int64_t w = 0; w < wSize; w += wSize) {
+ res = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, resVals[w], res,
+ /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+ /*strides=*/ArrayRef<int64_t>{1, 1});
+ }
+ //===------------------------------------------------------------------===//
+ // End vector-only rewrite part
+ //===------------------------------------------------------------------===//
+ // Write back res slice of size {n, w * c} @ [0, 0].
+ mlir::vector::TransferWriteOp tWrite =
+ rewriter.create<vector::TransferWriteOp>(loc, res, resFlat,
+ ValueRange{zero, zero});
+
+ // A tensor has to be re-shaped back to its original shape ...
+ if (linalgOp.hasTensorSemantics())
+ // Re-expand shape
+ return rewriter
+ .create<tensor::ExpandShapeOp>(loc, resShapedType, tWrite.getResult(),
+ reassociation)
+ .getOperation();
+ /// ... memrefs don't require reshaping (a re-shape is just a different view
+ /// into the same memref)
+ return tWrite.getOperation();
+ }
+
/// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res) {
@@ -2959,6 +3131,39 @@ struct Conv1DGenerator
return rewriter.create<arith::AddIOp>(loc, mul, res);
}
+ /// Lower lhs{n, w * c} * rhs{c} -> res{n, w * c} to MulAcc
+ Value depthwiseConv1dFlatSliceAsMulAcc(RewriterBase &rewriter, Location loc,
+ Value lhs, Value rhs, Value res) {
+ auto rhsTy = rhs.getType().cast<ShapedType>();
+ auto resTy = res.getType().cast<ShapedType>();
+
+ lhs = promote(rewriter, loc, lhs, resTy);
+
+ auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
+ auto resSize = res.getType().cast<VectorType>().getShape()[1];
+
+ SmallVector<int64_t, 16> indices;
+ for (int i = 0; i < resSize / rhsSize; ++i) {
+ for (int j = 0; j < rhsSize; ++j)
+ indices.push_back(j);
+ }
+
+ rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
+
+ rhs = rewriter.create<vector::BroadcastOp>(
+ loc, resTy.clone(rhsTy.getElementType()), rhs);
+ rhs = promote(rewriter, loc, rhs, resTy);
+
+ if (!lhs || !rhs)
+ return nullptr;
+
+ if (resTy.getElementType().isa<FloatType>())
+ return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+
+ auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
+ return rewriter.create<arith::AddIOp>(loc, mul, res);
+ }
+
/// Entry point for non-channeled convolution:
/// {{w + kw}, {kw}, {w}}
FailureOr<Operation *> generateNonChanneledConv() {
@@ -3049,7 +3254,7 @@ struct Conv1DGenerator
/// Entry point that transposes into the common form:
/// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
- FailureOr<Operation *> generateDilatedConv() {
+ FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
AffineExpr n, w, c, kw;
bindDims(ctx, n, w, c, kw);
if (!iters({Par(), Par(), Par(), Red()}))
@@ -3060,7 +3265,7 @@ struct Conv1DGenerator
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
/*rhsIndex*/ {kw, c},
/*resIndex*/ {n, w, c}}))
- return depthwiseConv();
+ return flatten ? depthwiseConvFlatten() : depthwiseConvGeneric();
return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
}
@@ -3125,8 +3330,9 @@ struct Conv1DGenerator
/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
- LinalgOp op) {
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+ bool flatten1DDepthwiseConv) {
// The ConvolutionOpInterface gives us guarantees of existence for
// strides/dilations. However, we do not need to rely on those, we can simply
// use them if present, otherwise use the default and let the generic conv.
@@ -3151,7 +3357,7 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
res = e.generateNcwPooling();
if (succeeded(res))
return res;
- return e.generateDilatedConv();
+ return e.generateDilatedConv(flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
new file mode 100644
index 000000000000000..6b0f920bfa42e7f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
@@ -0,0 +1,222 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+
+func.func @flatten_tensor(%input: tensor<1x8x3xi8>, %filter: tensor<1x3xi8>, %output: tensor<1x8x3xi8>) -> (tensor<1x8x3xi8>) {
+ %res = linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<1> : vector<1xi64>,
+ strides = dense<1> : vector<1xi64>}
+ ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
+ outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
+ return %res : tensor<1x8x3xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @flatten_tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x3xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> ...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/71918
More information about the Mlir-commits mailing list