[Mlir-commits] [mlir] 02f1f69 - [mlir][linalg] Add pure tensor check for `winogradConv2DHelper` (#142299)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 13 00:49:57 PDT 2025
Author: Longsheng Mou
Date: 2025-06-13T15:49:54+08:00
New Revision: 02f1f6967a847bba35fc207d61732f3466f39403
URL: https://github.com/llvm/llvm-project/commit/02f1f6967a847bba35fc207d61732f3466f39403
DIFF: https://github.com/llvm/llvm-project/commit/02f1f6967a847bba35fc207d61732f3466f39403.diff
LOG: [mlir][linalg] Add pure tensor check for `winogradConv2DHelper` (#142299)
This PR adds a pure-tensor-semantics check to `winogradConv2DHelper` to
prevent a crash when the convolution operates on memrefs. Fixes #141566.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index c6ebd3a53d981..e4221d4748415 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -904,6 +904,10 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) {
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
int64_t m, int64_t r) {
+ if (!convOp.hasPureTensorSemantics())
+ return rewriter.notifyMatchFailure(
+ convOp, "expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc");
+
Value input = convOp.getInputs()[0];
Value filter = convOp.getInputs()[1];
Value output = convOp.getOutputs()[0];
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
index c10e0ccebfd7c..1de861e653005 100644
--- a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -61,6 +61,22 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @conv2d_unsupported_type(%arg0: memref<2x10x10x5xf32>, %arg1: memref<2x3x3x5xf32>, %arg2: memref<2x8x8x2xf32>) {
+ linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : memref<2x10x10x5xf32>, memref<2x3x3x5xf32>) outs(%arg2 : memref<2x8x8x2xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @+1 {{apply Winograd Conv2D failed}}
+ %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> {
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
return %0 : tensor<2x?x?x2xf32>
More information about the Mlir-commits
mailing list