[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jun 20 05:48:35 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Hsiangkai Wang (Hsiangkai)
<details>
<summary>Changes</summary>
To support conv2d inputs of arbitrary size, implement TilingInterface for the winograd operators. Before the winograd operators are converted into nested loops of matrix multiplications, the conv2d input is first tiled into the supported size.
Add a transform operator, structured.decompose_winograd_op, that decomposes the winograd operators. Before applying it, use tile_using_for to tile the input data into the supported size. The added test case shows how to tile and decompose the winograd operators.
---
Patch is 58.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96184.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+18-3)
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+37)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+45)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+281)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+27)
- (modified) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+18)
- (added) mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir (+332)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index de1097b6ac27b..45726d6ee2224 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,7 +154,12 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
let hasVerifier = 1;
}
-def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
+ [DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
let summary = "Winograd filter transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -192,7 +197,12 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
let hasVerifier = 1;
}
-def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
+ [DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
let summary = "Winograd input transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -230,7 +240,12 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
let hasVerifier = 1;
}
-def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
+ [DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
let summary = "Winograd output transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 68d0f713caad4..71736eae38b4f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2638,4 +2638,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
}];
}
+// Transform op wrapping the decomposeWinograd*TransformOp rewrites. It is
+// applied to each payload op (TransformEachOpTrait + applyToOne) and expands a
+// linalg.winograd_{filter,input,output}_transform op; any other payload op
+// kind is reported as a silenceable failure.
+def DecomposeWinogradOp : Op<Transform_Dialect,
+ "structured.decompose_winograd_op",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Decompose winograd operators. It will convert filter, input and output
+ transform operators into a combination of scf, tensor, and linalg
+ equivalent operators. Before applying this transform operator, users
+ need to tile winograd transform operators into supported sizes.
+
+ #### Return modes:
+
+ This operation fails if `target` is unsupported. Otherwise, the operation
+ succeeds and returns a handle of the sequence that replaces the original
+ operator.
+ }];
+
+ // One payload handle in, one handle to the replacement out.
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type($target, results)";
+
+ let builders = [
+ // NOTE(review): a custom builder needs a hand-written C++ definition; it is
+ // not part of the visible (truncated) diff, so confirm that
+ // LinalgTransformOps.cpp defines it.
+ OpBuilder<(ins "Value":$target)>
+ ];
+
+ // applyToOne is defined in LinalgTransformOps.cpp and dispatches on the
+ // payload op type.
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bb7ec590faad0..d0eec2be1f8fb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1319,6 +1319,51 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
linalg::Conv2DNhwcFhwcOp op, int64_t m,
int64_t r);
+/// Rewrite linalg.winograd_filter_transform into loops of matmuls. The filter
+/// data layout is FHWC and the transformation matrices are two-dimensional, so
+/// each H x W slice is extracted from the FHWC filter first. Two levels of
+/// loops iterate over F and C. After the rewrite we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret into filter<h x w x c x f>
+///
+/// On success, returns the operation that replaces `op` (the
+/// structured.decompose_winograd_op transform exposes it as its result).
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+ linalg::WinogradFilterTransformOp op);
+
+/// Rewrite linalg.winograd_input_transform into loops of matmuls. The input
+/// data layout is NHWC and the transformation matrices are two-dimensional, so
+/// each H x W slice is extracted from the NHWC input first. Two levels of
+/// loops iterate over N and C. After the rewrite we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract input<h x w> from input<n x h x w x c>
+///     %ret = linalg.matmul BT, %extracted
+///     %ret = linalg.matmul %ret, B
+///     %inserted = insert %ret into input<h x w x n x c>
+///
+/// On success, returns the operation that replaces `op` (the
+/// structured.decompose_winograd_op transform exposes it as its result).
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+ linalg::WinogradInputTransformOp op);
+
+/// Rewrite linalg.winograd_output_transform into loops of matmuls. The data
+/// layout of the transformed value is HWNF and the transformation matrices are
+/// two-dimensional, so each H x W slice is extracted from the HWNF value first.
+/// Two levels of loops iterate over N and F. After the rewrite we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %f = lo_f to hi_f step 1
+///     %extracted = extract value<h x w> from value<h x w x n x f>
+///     %ret = linalg.matmul AT, %extracted
+///     %ret = linalg.matmul %ret, A
+///     %inserted = insert %ret into output<n x h x w x f>
+///
+/// On success, returns the operation that replaces `op` (the
+/// structured.decompose_winograd_op transform exposes it as its result).
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+ linalg::WinogradOutputTransformOp op);
+
//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7bf2a5bca037f..a416e1f6e257f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2760,6 +2760,89 @@ LogicalResult WinogradFilterTransformOp::verify() {
return success();
}
+/// Returns the iteration domain: one loop for each of the first six output
+/// dimensions, each starting at 0 with unit stride and spanning that
+/// dimension's extent.
+SmallVector<Range>
+WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value zeroIdx = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value oneIdx = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value dest = getOutput();
+  SmallVector<Range> domain;
+  domain.reserve(6);
+  for (int64_t d = 0; d != 6; ++d) {
+    OpFoldResult extent = getDimValue(builder, loc, dest, d);
+    domain.push_back(Range{zeroIdx, extent, oneIdx});
+  }
+  return domain;
+}
+
+/// Every one of the six loops in the iteration domain is parallel.
+SmallVector<utils::IteratorType>
+WinogradFilterTransformOp::getLoopIteratorTypes() {
+  return SmallVector<utils::IteratorType>(6, utils::IteratorType::parallel);
+}
+
+/// Materializes `opFoldResult` as an SSA value of `index` type.
+///
+/// A `Value` is returned unchanged. An integer attribute is materialized as an
+/// `arith.constant` of `index` type, whatever the attribute's own integer type
+/// is. The results feed `affine.apply`, whose operands must be `index`, so an
+/// i64 constant there would produce invalid IR.
+///
+/// The helper is file-local, so it is `static`. A non-static definition with
+/// such a generic name risks clashing with other translation units.
+static Value getValueFromOpFoldResult(OpFoldResult opFoldResult,
+                                      OpBuilder &builder, Location loc) {
+  if (auto val = llvm::dyn_cast_if_present<Value>(opFoldResult))
+    return val;
+  auto attr = llvm::dyn_cast_if_present<Attribute>(opFoldResult);
+  assert(attr && "expected a non-null OpFoldResult");
+  return builder.create<arith::ConstantIndexOp>(
+      loc, cast<IntegerAttr>(attr).getInt());
+}
+
+/// Computes the slice of the output written by a tile of the iteration domain.
+///
+/// The iteration domain mirrors the output dimension-for-dimension (see
+/// getIterationDomain). Dimensions 0 and 1 always advance one position at a
+/// time (size 1). Dimensions 2-5 take the tile's own offset and size.
+/// `resultNumber` is unused because the op produces a single result.
+LogicalResult WinogradFilterTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+
+  // Dimensions 0 and 1 are walked one position at a time.
+  resultOffsets.push_back(offsets[0]);
+  resultOffsets.push_back(offsets[1]);
+  resultSizes.push_back(oneAttr);
+  resultSizes.push_back(oneAttr);
+  // Use the tile's offset rather than a hard-coded 0. Otherwise every tile
+  // along a tiled dimension 2-5 would alias the first slice. For untiled
+  // dimensions the driver passes the domain offset, which is 0.
+  for (unsigned dim = 2; dim < 6; ++dim) {
+    resultOffsets.push_back(offsets[dim]);
+    resultSizes.push_back(sizes[dim]);
+  }
+  return success();
+}
+
+/// Builds the tiled winograd filter transform for the tile at `offsets` /
+/// `sizes` of the iteration domain.
+///
+/// The output (destination) operand is sliced to the tile's result position
+/// and the op is cloned onto that slice. The filter operand is passed through
+/// whole.
+FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+
+  Location loc = getLoc();
+  SmallVector<OpFoldResult> strides(6, oneAttr);
+  SmallVector<Value> tiledOperands;
+  // NOTE(review): the filter is not sliced. Tiling the iteration dims that
+  // presumably correspond to filter C/F would pair a full filter with a
+  // partial output slice; confirm that only supported dims are tiled.
+  tiledOperands.emplace_back(getFilter());
+
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+  // The op has a single result, so query the tile position of result 0.
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   sliceOffsets, sliceSizes)))
+    return failure();
+
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), sliceOffsets, sliceSizes, strides));
+
+  // The tiled op's result type is the type of the output slice it writes.
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//
@@ -2786,6 +2869,112 @@ LogicalResult WinogradInputTransformOp::verify() {
return success();
}
+/// Returns the iteration domain: one loop for each of the first six output
+/// dimensions, each starting at 0 with unit stride and spanning that
+/// dimension's extent.
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value zeroIdx = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value oneIdx = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value dest = getOutput();
+  SmallVector<Range> domain;
+  domain.reserve(6);
+  for (int64_t d = 0; d != 6; ++d) {
+    OpFoldResult extent = getDimValue(builder, loc, dest, d);
+    domain.push_back(Range{zeroIdx, extent, oneIdx});
+  }
+  return domain;
+}
+
+/// Every one of the six loops in the iteration domain is parallel.
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  return SmallVector<utils::IteratorType>(6, utils::IteratorType::parallel);
+}
+
+/// Computes the slice of the output written by a tile of the iteration domain.
+///
+/// The iteration domain mirrors the output dimension-for-dimension (see
+/// getIterationDomain). Dimensions 0 and 1 always advance one position at a
+/// time (size 1). Dimensions 2-5 take the tile's own offset and size.
+/// `resultNumber` is unused because the op produces a single result.
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+
+  // Dimensions 0 and 1 are walked one position at a time.
+  resultOffsets.push_back(offsets[0]);
+  resultOffsets.push_back(offsets[1]);
+  resultSizes.push_back(oneAttr);
+  resultSizes.push_back(oneAttr);
+  // Use the tile's offset rather than a hard-coded 0. Otherwise every tile
+  // along a tiled dimension 2-5 would alias the first slice. For untiled
+  // dimensions the driver passes the domain offset, which is 0.
+  for (unsigned dim = 2; dim < 6; ++dim) {
+    resultOffsets.push_back(offsets[dim]);
+    resultSizes.push_back(sizes[dim]);
+  }
+  return success();
+}
+
+/// Builds the tiled winograd input transform for the tile at `offsets` /
+/// `sizes` of the iteration domain.
+///
+/// Iteration dims 0 and 1 index input tiles. Tile (i, j) reads the
+/// alphaH x alphaW window of the input starting at (i * m, j * m) in input
+/// dims 1 and 2, where alpha = m + r - 1. The window is 1 wide along a spatial
+/// input dim of extent 1. Iteration dims 4 and 5 select input dims 0 and 3
+/// (N and C of the NHWC input). The output operand is sliced to the tile's
+/// result position and the op is cloned onto both slices.
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+                                                 ArrayRef<OpFoldResult> offsets,
+                                                 ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  Value input = getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  auto inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  // Along a spatial dim of extent 1 the window collapses to a single element.
+  int64_t alphaH = inputH != 1 ? alpha : 1;
+  int64_t alphaW = inputW != 1 ? alpha : 1;
+  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+  Location loc = getLoc();
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+  // Tile index -> element offset of the tile's window: d0 * m.
+  auto context = builder.getContext();
+  auto affineMap =
+      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[0], builder, loc));
+  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[1], builder, loc));
+
+  // Input dims 0 and 3 follow iteration dims 4 and 5. Use the tile offsets
+  // there: a hard-coded 0 would make every tile along those dims read the
+  // first slice of the input.
+  sliceOffsets.push_back(offsets[4]);
+  sliceOffsets.push_back(mappedOffset1);
+  sliceOffsets.push_back(mappedOffset2);
+  sliceOffsets.push_back(offsets[5]);
+  sliceSizes.push_back(sizes[4]);
+  sliceSizes.push_back(alphaHAttr);
+  sliceSizes.push_back(alphaWAttr);
+  sliceSizes.push_back(sizes[5]);
+  SmallVector<OpFoldResult> inputStrides(4, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
+
+  sliceOffsets.clear();
+  sliceSizes.clear();
+  // The op has a single result, so query the tile position of result 0.
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   sliceOffsets, sliceSizes)))
+    return failure();
+
+  SmallVector<OpFoldResult> outputStrides(6, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
+
+  // The tiled op's result type is the type of the output slice it writes.
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//
@@ -2812,6 +3001,98 @@ LogicalResult WinogradOutputTransformOp::verify() {
return success();
}
+/// Returns the iteration domain: one loop for each of the first six
+/// dimensions of the `value` operand, each starting at 0 with unit stride and
+/// spanning that dimension's extent.
+SmallVector<Range>
+WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value zeroIdx = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value oneIdx = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value src = getValue();
+  SmallVector<Range> domain;
+  domain.reserve(6);
+  for (int64_t d = 0; d != 6; ++d) {
+    OpFoldResult extent = getDimValue(builder, loc, src, d);
+    domain.push_back(Range{zeroIdx, extent, oneIdx});
+  }
+  return domain;
+}
+
+/// Every one of the six loops in the iteration domain is parallel.
+SmallVector<utils::IteratorType>
+WinogradOutputTransformOp::getLoopIteratorTypes() {
+  return SmallVector<utils::IteratorType>(6, utils::IteratorType::parallel);
+}
+
+/// Computes the slice of the 4-D output written by a tile of the iteration
+/// domain.
+///
+/// Iteration dims 0 and 1 index output tiles. Tile (i, j) writes the
+/// mH x mW window of output dims 1 and 2 starting at (i * m, j * m). mH / mW
+/// equal m, or 1 along a spatial output dim of extent 1. Iteration dims 4 and
+/// 5 map to output dims 0 and 3. `resultNumber` is unused because the op
+/// produces a single result.
+LogicalResult WinogradOutputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  int64_t m = getM();
+  // Mirror WinogradInputTransformOp: along a spatial dim of extent 1 the
+  // window collapses to 1. A size of m there would make the slice run past
+  // the end of that dimension.
+  ArrayRef<int64_t> outputShape =
+      cast<ShapedType>(getOutput().getType()).getShape();
+  int64_t mH = outputShape[1] != 1 ? m : 1;
+  int64_t mW = outputShape[2] != 1 ? m : 1;
+  Location loc = getLoc();
+  auto context = builder.getContext();
+  // Tile index -> element offset of the tile's window: d0 * m.
+  auto affineMap =
+      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[0], builder, loc));
+  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[1], builder, loc));
+
+  // Output dims 0 and 3 follow iteration dims 4 and 5. Use the tile offsets
+  // there rather than 0 so that tiles along those dims do not alias.
+  resultOffsets.push_back(offsets[4]);
+  resultOffsets.push_back(mappedOffset1);
+  resultOffsets.push_back(mappedOffset2);
+  resultOffsets.push_back(offsets[5]);
+  resultSizes.push_back(sizes[4]);
+  resultSizes.push_back(builder.getI64IntegerAttr(mH));
+  resultSizes.push_back(builder.getI64IntegerAttr(mW));
+  resultSizes.push_back(sizes[5]);
+  return success();
+}
+
+/// Builds the tiled winograd output transform for the tile at `offsets` /
+/// `sizes` of the iteration domain.
+///
+/// The `value` operand is sliced to one position along iteration dims 0 and 1
+/// and to the tile along dims 2-5. The output operand is sliced to the tile's
+/// result position. The op is then cloned onto the two slices.
+FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  Location loc = getLoc();
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+  // Dims 0 and 1 are taken one position at a time.
+  sliceOffsets.push_back(offsets[0]);
+  sliceOffsets.push_back(offsets[1]);
+  sliceSizes.push_back(oneAttr);
+  sliceSizes.push_back(oneAttr);
+  // Dims 2-5 follow the tile. Use the tile offsets rather than 0 so that
+  // tiles along these dims do not all read the first slice.
+  for (unsigned dim = 2; dim < 6; ++dim) {
+    sliceOffsets.push_back(offsets[dim]);
+    sliceSizes.push_back(sizes[dim]);
+  }
+  SmallVector<OpFoldResult> sliceStrides(6, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
+
+  sliceOffsets.clear();
+  sliceSizes.clear();
+  // The op has a single result, so query the tile position of result 0.
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   sliceOffsets, sliceSizes)))
+    return failure();
+
+  SmallVector<OpFoldResult> strides(4, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), sliceOffsets, sliceSizes, strides));
+
+  // The tiled op's result type is the type of the output slice it writes.
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d051b29e1f06f..358c15f145407 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3505,6 +3505,33 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+/// Decomposes one winograd filter/input/output transform payload op through
+/// the matching decomposeWinograd*TransformOp entry point. On success, the
+/// replacement operation is recorded as this op's result. Any other op kind,
+/// or a failed decomposition, yields the default silenceable failure.
+DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // New IR is created immediately before the op being decomposed.
+  rewriter.setInsertionPoint(target);
+  FailureOr<Operation *> decomposed =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case<linalg::WinogradFilterTransformOp>([&](auto filterOp) {
+            return decomposeWinogradFilterTransformOp(rewriter, filterOp);
+          })
+          .Case<linalg::WinogradInputTransformOp>([&](auto inputOp) {
+            return decomposeWinogradInputTransformOp(rewriter, inputOp);
+          })
+          .Case<linalg::WinogradOutputTransformOp>([&](auto outputOp) {
+            return decomposeWinogradOutputTransformOp(rewriter, outputOp);
+          })
+          .Default([&](Operation *unsupported) {
+            return rewriter.notifyMatchFailure(unsupported, "not supported");
+          });
+
+  if (succeeded(decomposed)) {
+    results.push_back(*decomposed);
+    return DiagnosedSilenceableFailure::success();
+  }
+  return emitDefaultSilenceableFailure(target);
+}
+
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transfo...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/96184
More information about the llvm-branch-commits
mailing list