[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)
Hsiangkai Wang via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jun 27 01:41:00 PDT 2024
================
@@ -289,6 +938,123 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
return transformedOutput.getDefiningOp();
}
+/// Decomposes a `linalg::WinogradFilterTransformOp` by calling
+/// `filterTransform` and replacing `op` with the value it produces.
+///
+/// Returns the operation defining the replacement value, or failure when
+/// `filterTransform` returns a null value (in which case `op` is untouched).
+FailureOr<Operation *>
+decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradFilterTransformOp op) {
+  Value filter = op.getFilter();
+  auto dims = cast<ShapedType>(filter.getType()).getShape();
+
+  // Dimensions 1 and 2 of the filter shape are read as height and width.
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  const bool needsLeft = dims[1] != 1;
+  const bool needsRight = dims[2] != 1;
+
+  Value result =
+      filterTransform(rewriter, op.getLoc(), filter, op.getOutput(),
+                      op.getM(), op.getR(), needsLeft, needsRight);
+  if (result) {
+    rewriter.replaceOp(op, result);
+    return result.getDefiningOp();
+  }
+
+  return failure();
+}
+
+/// Decomposes a `linalg::WinogradInputTransformOp` by calling
+/// `inputTransform` and replacing `op` with the value it produces.
+///
+/// \param rewriter Rewriter used to build the replacement and replace `op`.
+/// \param op       The input transform operation to decompose.
+/// \returns The operation defining the replacement value, or failure when
+///          `inputTransform` returns a null value (`op` is then untouched).
+FailureOr<Operation *>
+decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
+                                      linalg::WinogradInputTransformOp op) {
+  Location loc = op.getLoc();
+  Value input = op.getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  auto inputShape = inputType.getShape();
+  // Dimensions 1 and 2 of the input shape are read as height and width.
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = inputH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = inputW != 1;
+  // Reuse the cached `input` operand rather than querying the op again, in
+  // line with the sibling filter/output helpers.
+  Value transformedInput =
+      inputTransform(rewriter, loc, input, op.getOutput(), op.getM(),
+                     op.getR(), leftTransform, rightTransform);
+  if (!transformedInput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedInput);
+
+  return transformedInput.getDefiningOp();
+}
+
+/// Decomposes a `linalg::WinogradOutputTransformOp` by calling
+/// `outputTransform` and replacing `op` with the value it produces.
+///
+/// Returns the operation defining the replacement value, or failure when
+/// `outputTransform` returns a null value (in which case `op` is untouched).
+FailureOr<Operation *>
+decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradOutputTransformOp op) {
+  Value value = op.getValue();
+  auto dims = cast<ShapedType>(value.getType()).getShape();
+
+  // Dimensions 2 and 3 of the value shape are read as height and width.
+  const int64_t height = dims[2];
+  const int64_t width = dims[3];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  const bool applyLeft = !(height == 1);
+  const bool applyRight = !(width == 1);
+
+  Value replacement =
+      outputTransform(rewriter, op.getLoc(), value, op.getOutput(),
+                      op.getM(), op.getR(), applyLeft, applyRight);
+  if (!replacement)
+    return failure();
+
+  rewriter.replaceOp(op, replacement);
+  return replacement.getDefiningOp();
+}
+
+/// Rewrite pattern that decomposes `linalg::WinogradFilterTransformOp` by
+/// forwarding to `decomposeWinogradFilterTransformHelper`.
+class DecomposeWinogradFilterTransform final
+    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only the success/failure state of the helper is propagated; the
+    // operation it returns is not needed here.
+    return success(
+        succeeded(decomposeWinogradFilterTransformHelper(rewriter, op)));
+  }
+};
+
+class DecomposeWinogradInputTransform final
+ : public OpRewritePattern<linalg::WinogradInputTransformOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
+ PatternRewriter &rewriter) const override {
+ if (failed(decomposeWinogradInputTransformHelper(rewriter, op)))
+ return failure();
+
+ return success();
----------------
Hsiangkai wrote:
Done.
https://github.com/llvm/llvm-project/pull/96183
More information about the llvm-branch-commits
mailing list