[Mlir-commits] [mlir] [tosa] : Relax dynamic dimension checks for batch for conv decompositions (PR #168764)
Sayan Saha
llvmlistbot at llvm.org
Wed Nov 19 11:38:48 PST 2025
https://github.com/sahas3 created https://github.com/llvm/llvm-project/pull/168764
This PR relaxes the validation checks in the TOSA depthwise_conv2d and transpose_conv2d decomposition patterns so that the input and output tensors may have a dynamic batch dimension. All other input/output dimensions must still be static, as must the weight shape (and, for the transpose_conv2d patterns, the bias shape).
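As an illustration, the depthwise test added in this patch exercises exactly this case: a depthwise_conv2d whose input/output batch dimension (and bias) is dynamic was previously rejected outright, and is now decomposed into a reshape/mul/add sequence:

  %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp
         {acc_type = f32, pad = array<i64: 0, 0, 0, 0>,
          stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>}
       : (tensor<?x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<?xf32>,
          tensor<1xf32>, tensor<1xf32>) -> tensor<?x10x10x6xf32>

Here %arg0 is the NHWC input with a dynamic batch, %arg1 the 1x1 weights, and %arg2 the bias.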
From 37122655738dfc1c80ccee0703ce336052933584 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Wed, 19 Nov 2025 09:29:47 -0500
Subject: [PATCH] [tosa] : Relax dynamic dimension checks for batch for conv
decompositions.
---
.../Transforms/TosaDecomposeDepthwise.cpp | 9 ++++++--
.../Transforms/TosaDecomposeTransposeConv.cpp | 18 +++++++++++----
.../Tosa/tosa-decompose-depthwise.mlir | 23 +++++++++++++++++++
.../Tosa/tosa-decompose-transpose-conv.mlir | 21 +++++++++++++++++
4 files changed, 65 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 0bec0da3f4320..022476a2f44cf 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -33,8 +33,13 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
ShapedType weightType = cast<ShapedType>(weight.getType());
ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
- if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
- resultType.hasStaticShape())) {
+ // All input/output dimensions other than the batch (dim 0) must be static.
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightType.hasStaticShape()) {
return failure();
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index dc5c51b0abad5..8b23fd1341bc5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -49,8 +49,13 @@ class TransposeConvNonStridedConverter
if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
return failure();
- if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
- !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+ // All input/output dimensions other than the batch (dim 0) must be static.
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return failure();
int64_t kernelHeight = weightTy.getDimSize(1);
@@ -113,8 +118,13 @@ class TransposeConvStridedConverter
if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
return rewriter.notifyMatchFailure(op, "non-one stride found.");
- if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
- !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+ // All input/output dimensions other than the batch (dim 0) must be static.
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return failure();
int64_t batch = inputTy.getDimSize(0);
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index c7eeb5281679b..d4c4595e84ee0 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -98,3 +98,26 @@ func.func @depthwise_conv2d_no_const_zero_point(%arg0: tensor<4x10x10x2xi8>, %ar
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32>
return %0 : tensor<4x10x10x6xi32>
}
+
+// -----
+// CHECK-LABEL: func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(
+// CHECK-SAME: %[[INP:.*]]: tensor<?x10x10x2xf32>,
+// CHECK-SAME: %[[WTS:.*]]: tensor<1x1x2x3xf32>,
+// CHECK-SAME: %[[BIAS:.*]]: tensor<?xf32>) -> tensor<?x10x10x6xf32> {
+// CHECK: %[[BIAS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[RES_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 6]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[WTS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>} : () -> !tosa.shape<5>
+// CHECK: %[[INP_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 2, 1]> : tensor<5xindex>} : () -> !tosa.shape<5>
+// CHECK: %[[INP_RESHAPED:.*]] = tosa.reshape %[[INP]], %[[INP_EXPANDED_SHAPE]] : (tensor<?x10x10x2xf32>, !tosa.shape<5>) -> tensor<?x10x10x2x1xf32>
+// CHECK: %[[WTS_RESHAPED:.*]] = tosa.reshape %[[WTS]], %[[WTS_EXPANDED_SHAPE]] : (tensor<1x1x2x3xf32>, !tosa.shape<5>) -> tensor<1x1x1x2x3xf32>
+// CHECK: %[[MUL:.*]] = tosa.mul %[[INP_RESHAPED]], %[[WTS_RESHAPED]], %[[MUL_SHIFT]] : (tensor<?x10x10x2x1xf32>, tensor<1x1x1x2x3xf32>, tensor<1xi8>) -> tensor<?x10x10x2x3xf32>
+// CHECK: %[[RES_RESHAPED:.*]] = tosa.reshape %[[MUL]], %[[RES_EXPANDED_SHAPE]] : (tensor<?x10x10x2x3xf32>, !tosa.shape<4>) -> tensor<?x10x10x6xf32>
+// CHECK: %[[BIAS_RESHAPED:.*]] = tosa.reshape %[[BIAS]], %[[BIAS_EXPANDED_SHAPE]] : (tensor<?xf32>, !tosa.shape<4>) -> tensor<1x1x1x?xf32>
+// CHECK: %[[RES:.*]] = tosa.add %[[RES_RESHAPED]], %[[BIAS_RESHAPED]] : (tensor<?x10x10x6xf32>, tensor<1x1x1x?xf32>) -> tensor<?x10x10x6xf32>
+// CHECK: return %[[RES]]
+func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(%arg0: tensor<?x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<?xf32>) -> tensor<?x10x10x6xf32> {
+ %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<?x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x10x10x6xf32>
+ return %0 : tensor<?x10x10x6xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 810135f6f531b..61ca0aedf6a46 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -181,3 +181,24 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
(tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
"func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
}
+
+
+// -----
+// CHECK-LABEL: @transpose_conv2d_non_strided_dynamic_batch
+// CHECK: tosa.conv2d
+// CHECK-NOT: tosa.transpose_conv2d
+func.func @transpose_conv2d_non_strided_dynamic_batch(%arg0: tensor<?x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x18x19x5xf32> {
+ %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x18x19x5xf32>
+ return %0 : tensor<?x18x19x5xf32>
+}
+
+// -----
+// CHECK-LABEL: @transpose_conv2d_strided_dynamic_batch
+// CHECK: tosa.conv2d
+// CHECK-NOT: tosa.transpose_conv2d
+func.func @transpose_conv2d_strided_dynamic_batch(%arg0: tensor<?x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x35x47x5xf32> {
+ %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<?x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x35x47x5xf32>
+ return %0 : tensor<?x35x47x5xf32>
+}
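For reviewers who want to reproduce locally: the existing RUN lines in these test files feed each split input through the optional-decompositions pipeline, so from a configured build tree something like the following (paths illustrative, assuming the usual build layout) should exercise the new cases:

  bin/mlir-opt --split-input-file --tosa-optional-decompositions \
      mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir

Alternatively, run both updated files under llvm-lit (e.g. bin/llvm-lit -v mlir/test/Dialect/Tosa/).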