[llvm-branch-commits] [mlir] [mlir][linalg] Decompose winograd operators (PR #96183)

Hsiangkai Wang via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jun 27 01:41:00 PDT 2024


================
@@ -289,6 +938,123 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
   return transformedOutput.getDefiningOp();
 }
 
+/// Replaces a linalg::WinogradFilterTransformOp with the value produced by
+/// filterTransform() for its filter and output operands.
+///
+/// Returns the operation defining the replacement value, or failure() when
+/// filterTransform() yields a null value (in which case `op` is left in
+/// place).
+FailureOr<Operation *>
+decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradFilterTransformOp op) {
+  Value filter = op.getFilter();
+  auto type = cast<ShapedType>(filter.getType());
+  auto dims = type.getShape();
+
+  // dims[1] is treated as the filter height and dims[2] as its width. A side
+  // whose extent is 1 is skipped: F(m x 1, r x 1) needs only the left-side
+  // transform and F(1 x m, 1 x r) needs only the right-side transform.
+  const bool needsLeft = dims[1] != 1;
+  const bool needsRight = dims[2] != 1;
+
+  Value result = filterTransform(rewriter, op.getLoc(), filter, op.getOutput(),
+                                 op.getM(), op.getR(), needsLeft, needsRight);
+  if (!result)
+    return failure();
+
+  rewriter.replaceOp(op, result);
+  return result.getDefiningOp();
+}
+
+/// Replaces a linalg::WinogradInputTransformOp with the value produced by
+/// inputTransform() for its input and output operands.
+///
+/// Returns the operation defining the replacement value, or failure() when
+/// inputTransform() yields a null value (in which case `op` is left in
+/// place).
+FailureOr<Operation *>
+decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
+                                      linalg::WinogradInputTransformOp op) {
+  Location loc = op.getLoc();
+  Value input = op.getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  auto inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = inputH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = inputW != 1;
+  // Pass the already-fetched `input` value instead of querying the op a
+  // second time, matching the filter and output helpers.
+  Value transformedInput =
+      inputTransform(rewriter, loc, input, op.getOutput(), op.getM(),
+                     op.getR(), leftTransform, rightTransform);
+  if (!transformedInput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedInput);
+
+  return transformedInput.getDefiningOp();
+}
+
+/// Replaces a linalg::WinogradOutputTransformOp with the value produced by
+/// outputTransform() for its value and output operands.
+///
+/// Returns the operation defining the replacement value, or failure() when
+/// outputTransform() yields a null value (in which case `op` is left in
+/// place).
+FailureOr<Operation *>
+decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradOutputTransformOp op) {
+  Value source = op.getValue();
+  auto type = cast<ShapedType>(source.getType());
+  auto dims = type.getShape();
+
+  // dims[2] is treated as the height and dims[3] as the width. A side whose
+  // extent is 1 is skipped: F(m x 1, r x 1) needs only the left-side
+  // transform and F(1 x m, 1 x r) needs only the right-side transform.
+  const bool needsLeft = dims[2] != 1;
+  const bool needsRight = dims[3] != 1;
+
+  Value result = outputTransform(rewriter, op.getLoc(), source, op.getOutput(),
+                                 op.getM(), op.getR(), needsLeft, needsRight);
+  if (!result)
+    return failure();
+
+  rewriter.replaceOp(op, result);
+  return result.getDefiningOp();
+}
+
+/// Rewrite pattern that forwards every linalg::WinogradFilterTransformOp to
+/// decomposeWinogradFilterTransformHelper() and reports whether it succeeded.
+class DecomposeWinogradFilterTransform final
+    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    // The helper performs the replacement itself; only its status is needed.
+    FailureOr<Operation *> decomposed =
+        decomposeWinogradFilterTransformHelper(rewriter, op);
+    return failed(decomposed) ? failure() : success();
+  }
+};
+
+class DecomposeWinogradInputTransform final
+    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradInputTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
----------------
Hsiangkai wrote:

Done.

https://github.com/llvm/llvm-project/pull/96183


More information about the llvm-branch-commits mailing list