[Mlir-commits] [mlir] [mlir][tosa] Require PadOp's pad_const to be rank1 (PR #129156)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 27 15:54:52 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Jerry-Ge (Jerry-Ge)

<details>
<summary>Changes</summary>

Update PadOp's `pad_const` operand to be a rank-1 tensor (previously it was required to be rank-0).

Fix various lit tests affected by this change, including some conv op tests whose zero-point constants now use rank-1 tensors.

---
Full diff: https://github.com/llvm/llvm-project/pull/129156.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+1-1) 
- (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+3-3) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+10-9) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..65e5956e2072b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1883,7 +1883,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
   let arguments = (ins
     Tosa_RankedTensor:$input1,
     Tosa_Shape:$padding,
-    Optional<Tosa_Rank0Tensor>:$pad_const,
+    Optional<Tosa_ScalarTensor>:$pad_const,
     OptionalAttr<I32Attr>:$input_zp
   );
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index dcba9ef67a008..363b5958bc0fd 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -206,7 +206,7 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
     }
 
     auto denseAttr = DenseElementsAttr::get(
-        RankedTensorType::get({}, elementTy), constantAttr);
+        RankedTensorType::get({1}, elementTy), constantAttr);
     auto constantVal = rewriter.create<tosa::ConstOp>(
         op.getLoc(), denseAttr.getType(), denseAttr);
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index fc945928e4908..25a159bbc9644 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -127,7 +127,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
 
       Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
 
-      auto padTy = RankedTensorType::get({}, inputETy);
+      auto padTy = RankedTensorType::get({1}, inputETy);
       auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
       Value padVal =
           rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index c2eaba4c563d0..6b7f622d3303f 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -542,8 +542,8 @@ func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
   // CHECK: tensor.pad %[[ARG0]] low{{\[}}[[INDEX1]], [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
   // CHECK:   tensor.yield [[CST]]
   // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
-  %1 = arith.constant dense<42.0> : tensor<f32>
-  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, !tosa.shape<4>, tensor<f32>)  -> (tensor<4x9xf32>)
+  %1 = arith.constant dense<42.0> : tensor<1xf32>
+  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, !tosa.shape<4>, tensor<1xf32>)  -> (tensor<4x9xf32>)
   return %2 : tensor<4x9xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 66de6b23eae01..175145f332f8e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -288,7 +288,7 @@ func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32>
 
 // CHECK-LABEL: @pad_determine_val_i32
 func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
-  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi32>}
   // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
   %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
@@ -300,7 +300,7 @@ func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>
 
 // CHECK-LABEL: @pad_determine_val_f32
 func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
-  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}
   // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
   %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
@@ -312,7 +312,7 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
 
 // CHECK-LABEL: @pad_determine_val_quant
 func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
-  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
+  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi32>}
   // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
   %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 123c65e1b4fcd..c68d1eb2b007b 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -218,10 +218,10 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
 
 // -----
 
-func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> {
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
   %0 = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
   // expected-error at +1 {{'tosa.pad' op pad_const of pad is not constant}}
-  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<i8>) -> tensor<13x21x3xi8>
+  %1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
   return %1 : tensor<13x21x3xi8>
 }
 
@@ -248,7 +248,7 @@ func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
 func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
   %0 = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
   %1 = "tosa.const"() {value = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
-  // expected-error at +1 {{'tosa.pad' op operand #2 must be 0D tensor of number values, but got 'tensor<2xf32>'}}
+  // expected-error at +1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}}
   %2 = tosa.pad %arg0, %0, %1 : (tensor<13x21xf32>, !tosa.shape<4>, tensor<2xf32>) -> tensor<13x21xf32>
   return
 }
@@ -545,22 +545,22 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
 // -----
 
 func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
-  %input_zp = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
-  %weight_zp = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+  %input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %weight_zp = "tosa.const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
   // expected-error at +1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x29x0x4xf32>'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-           : (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>, tensor<i32>, tensor<i32>) -> tensor<1x27x27x16xf32>
+           : (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x27x27x16xf32>
   return %0 : tensor<1x27x27x16xf32>
 }
 
 // -----
 
 func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
-  %input_zp = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
-  %weight_zp = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+  %input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %weight_zp = "tosa.const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
   // expected-error at +1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x?x0x4xf32>'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-           : (tensor<1x?x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>, tensor<i32>, tensor<i32>) -> tensor<1x27x27x16xf32>
+           : (tensor<1x?x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x27x27x16xf32>
   return %0 : tensor<1x27x27x16xf32>
 }
 
@@ -1055,6 +1055,7 @@ func.func @test_shape_type(%arg0: !tosa.shape<-1>) -> !tosa.shape<-1> {
 }
 
 // -----
+
 func.func @test_const_shape() -> !tosa.shape<4> {
   // expected-error at +1 {{'tosa.const_shape' op attribute 'value' failed to satisfy constraint: index elements attribute}}
   %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> !tosa.shape<4>
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index bea73ab92f2e3..cb45c4465cde6 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -591,9 +591,9 @@ func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
 // -----
 // CHECK-LABEL: pad_explicit_value
 func.func @test_pad_explicit_value(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = "tosa.const"() {value = dense<3.14> : tensor<f32>} : () -> tensor<f32>
+  %0 = "tosa.const"() {value = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
   %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
-  %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<f32>) -> tensor<13x21x3xf32>
+  %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x3xf32>
   return %1 : tensor<13x21x3xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index f9f3c074b3716..7cd44ba475dbb 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -60,13 +60,13 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
 func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x12x12x6xf32> {
   // CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {value = dense<[4, 10, 10, 2, 1]> : tensor<5xindex>}
   // CHECK-DAG: %[[PAD:.+]] = tosa.const_shape  {value = dense<[0, 0, 1, 1, 1, 1, 0, 0, 0, 0]> : tensor<10xindex>} : () -> !tosa.shape<10>
-  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}
+  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}
   // CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>}
   // CHECK-DAG: %[[CONST4:.+]] = tosa.const_shape {value = dense<[4, 12, 12, 6]> : tensor<4xindex>}
   // CHECK-DAG: %[[CONST5:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex>}
   // CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   // CHECK: %[[RESHAPE_I:.+]] = tosa.reshape %arg0, %[[CONST0]]
-  // CHECK: %[[PAD_I:.+]] = tosa.pad %[[RESHAPE_I]], %[[PAD]], %[[ZERO]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
+  // CHECK: %[[PAD_I:.+]] = tosa.pad %[[RESHAPE_I]], %[[PAD]], %[[ZERO]] : (tensor<4x10x10x2x1xf32>, !tosa.shape<10>, tensor<1xf32>) -> tensor<4x12x12x2x1xf32>
   // CHECK: %[[RESHAPE_ARG1:.+]] = tosa.reshape %arg1, %[[CONST3]]
   // CHECK: %[[MUL:.+]] = tosa.mul %[[PAD_I]], %[[RESHAPE_ARG1]], %[[SHIFT]]
   // CHECK: %[[RESHAPE_O:.+]] = tosa.reshape %[[MUL]], %[[CONST4]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/129156


More information about the Mlir-commits mailing list