[Mlir-commits] [mlir] [tosa] : Relax dynamic dimension checks for batch for conv decompositions (PR #168764)

Sayan Saha llvmlistbot at llvm.org
Wed Nov 19 11:38:48 PST 2025


https://github.com/sahas3 created https://github.com/llvm/llvm-project/pull/168764

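This PR relaxes the shape validation in the TOSA depthwise-conv and transpose-conv decomposition patterns so that the input and output tensors may have a dynamic batch dimension. All non-batch input/output dimensions must still be static, as must the weight shape (and, for the transpose-conv patterns, the bias shape).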

From 37122655738dfc1c80ccee0703ce336052933584 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Wed, 19 Nov 2025 09:29:47 -0500
Subject: [PATCH] [tosa] : Relax dynamic dimension checks for batch for conv
 decompositions.

---
 .../Transforms/TosaDecomposeDepthwise.cpp     |  9 ++++++--
 .../Transforms/TosaDecomposeTransposeConv.cpp | 18 +++++++++++----
 .../Tosa/tosa-decompose-depthwise.mlir        | 23 +++++++++++++++++++
 .../Tosa/tosa-decompose-transpose-conv.mlir   | 21 +++++++++++++++++
 4 files changed, 65 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 0bec0da3f4320..022476a2f44cf 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -33,8 +33,13 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
     ShapedType weightType = cast<ShapedType>(weight.getType());
     ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
 
-    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
-          resultType.hasStaticShape())) {
+    // Any dimensions other than batchSize cannot be dynamic for input/output
+    for (unsigned int i = 1; i < 4; ++i) {
+      if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i))
+        return failure();
+    }
+
+    if (!weightType.hasStaticShape()) {
       return failure();
     }
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index dc5c51b0abad5..8b23fd1341bc5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -49,8 +49,13 @@ class TransposeConvNonStridedConverter
     if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
       return failure();
 
-    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
-        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+    // Any dimensions other than batchSize cannot be dynamic for input/output
+    for (unsigned int i = 1; i < 4; ++i) {
+      if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+        return failure();
+    }
+
+    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return failure();
 
     int64_t kernelHeight = weightTy.getDimSize(1);
@@ -113,8 +118,13 @@ class TransposeConvStridedConverter
     if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
       return rewriter.notifyMatchFailure(op, "non-one stride found.");
 
-    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
-        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+    // Any dimensions other than batchSize cannot be dynamic for input/output
+    for (unsigned int i = 1; i < 4; ++i) {
+      if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+        return failure();
+    }
+
+    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return failure();
 
     int64_t batch = inputTy.getDimSize(0);
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index c7eeb5281679b..d4c4595e84ee0 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -98,3 +98,26 @@ func.func @depthwise_conv2d_no_const_zero_point(%arg0: tensor<4x10x10x2xi8>, %ar
   %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32>
   return %0 : tensor<4x10x10x6xi32>
 }
+
+// -----
+// CHECK-LABEL:   func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(
+// CHECK-SAME:      %[[INP:.*]]: tensor<?x10x10x2xf32>,
+// CHECK-SAME:      %[[WTS:.*]]: tensor<1x1x2x3xf32>,
+// CHECK-SAME:      %[[BIAS:.*]]: tensor<?xf32>) -> tensor<?x10x10x6xf32> {
+// CHECK:           %[[BIAS_EXPANDED_SHAPE:.*]] = tosa.const_shape  {values = dense<[1, 1, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK:           %[[RES_EXPANDED_SHAPE:.*]] = tosa.const_shape  {values = dense<[-1, 10, 10, 6]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK:           %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK:           %[[WTS_EXPANDED_SHAPE:.*]] = tosa.const_shape  {values = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>} : () -> !tosa.shape<5>
+// CHECK:           %[[INP_EXPANDED_SHAPE:.*]] = tosa.const_shape  {values = dense<[-1, 10, 10, 2, 1]> : tensor<5xindex>} : () -> !tosa.shape<5>
+// CHECK:           %[[INP_RESHAPED:.*]] = tosa.reshape %[[INP]], %[[INP_EXPANDED_SHAPE]] : (tensor<?x10x10x2xf32>, !tosa.shape<5>) -> tensor<?x10x10x2x1xf32>
+// CHECK:           %[[WTS_RESHAPED:.*]] = tosa.reshape %[[WTS]], %[[WTS_EXPANDED_SHAPE]] : (tensor<1x1x2x3xf32>, !tosa.shape<5>) -> tensor<1x1x1x2x3xf32>
+// CHECK:           %[[MUL:.*]] = tosa.mul %[[INP_RESHAPED]], %[[WTS_RESHAPED]], %[[MUL_SHIFT]] : (tensor<?x10x10x2x1xf32>, tensor<1x1x1x2x3xf32>, tensor<1xi8>) -> tensor<?x10x10x2x3xf32>
+// CHECK:           %[[RES_RESHAPED:.*]] = tosa.reshape %[[MUL]], %[[RES_EXPANDED_SHAPE]] : (tensor<?x10x10x2x3xf32>, !tosa.shape<4>) -> tensor<?x10x10x6xf32>
+// CHECK:           %[[BIAS_RESHAPED:.*]] = tosa.reshape %[[BIAS]], %[[BIAS_EXPANDED_SHAPE]] : (tensor<?xf32>, !tosa.shape<4>) -> tensor<1x1x1x?xf32>
+// CHECK:           %[[RES:.*]] = tosa.add %[[RES_RESHAPED]], %[[BIAS_RESHAPED]] : (tensor<?x10x10x6xf32>, tensor<1x1x1x?xf32>) -> tensor<?x10x10x6xf32>
+// CHECK:           return %[[RES]]
+func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(%arg0: tensor<?x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<?xf32>) -> tensor<?x10x10x6xf32> {
+  %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<?x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x10x10x6xf32>
+  return %0 : tensor<?x10x10x6xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 810135f6f531b..61ca0aedf6a46 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -181,3 +181,24 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
     (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
   "func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
 }
+
+
+// -----
+// CHECK-LABEL: @transpose_conv2d_non_strided_dynamic_batch
+// CHECK: tosa.conv2d
+// CHECK-NOT: tosa.transpose_conv2d
+func.func @transpose_conv2d_non_strided_dynamic_batch(%arg0: tensor<?x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x18x19x5xf32> {
+  %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x18x19x5xf32>
+  return %0 : tensor<?x18x19x5xf32>
+}
+
+// -----
+// CHECK-LABEL: @transpose_conv2d_strided_dynamic_batch
+// CHECK: tosa.conv2d
+// CHECK-NOT: tosa.transpose_conv2d
+func.func @transpose_conv2d_strided_dynamic_batch(%arg0: tensor<?x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x35x47x5xf32> {
+  %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<?x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x35x47x5xf32>
+  return %0 : tensor<?x35x47x5xf32>
+}



More information about the Mlir-commits mailing list