[Mlir-commits] [mlir] 751c3f5 - [mlir][tosa] Update TileOp infer shape (#134732)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 9 09:29:33 PDT 2025


Author: Jerry-Ge
Date: 2025-04-09T09:29:29-07:00
New Revision: 751c3f51eb65b493bf2fc2c6b2788f89e16a1fbe

URL: https://github.com/llvm/llvm-project/commit/751c3f51eb65b493bf2fc2c6b2788f89e16a1fbe
DIFF: https://github.com/llvm/llvm-project/commit/751c3f51eb65b493bf2fc2c6b2788f89e16a1fbe.diff

LOG: [mlir][tosa] Update TileOp infer shape (#134732)

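Update TileOp's shape inference to read the multiples via getConstShapeValues.
When the multiples are not constant, inference now succeeds with an all-dynamic
shape of the multiples' rank; a multiple of -1 yields a dynamic output
dimension; and the inferred result now carries the input's element type.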

Signed-off-by: Tai Ly <tai.ly at arm.com>
Co-authored-by: Tai Ly <tai.ly at arm.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 59946ca54b933..5941be8403480 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1616,19 +1616,25 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TileOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  DenseIntElementsAttr multiplesAttr;
-  if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
-    return failure();
-
-  SmallVector<int64_t> multiples = llvm::to_vector(
-      llvm::map_range(multiplesAttr.getValues<APInt>(),
-                      [](const APInt &val) { return val.getSExtValue(); }));
+  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
+  SmallVector<int64_t> multiples;
+  if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
+                                 multiples)) {
+    auto rank =
+        cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
+    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+    inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+    return success();
+  } else {
+    multiples = convertToMlirShape(multiples);
+  }
 
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
   SmallVector<int64_t> outputShape;
   if (!inputShape.hasRank()) {
     outputShape.resize(multiples.size(), ShapedType::kDynamic);
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
     return failure();
@@ -1636,13 +1642,17 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
   // Any non dynamic dimension can be multiplied to a known size.
   outputShape.reserve(multiples.size());
   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
-    int64_t dim = inputShape.getDimSize(i);
-    if (dim != ShapedType::kDynamic)
-      dim *= multiples[i];
-    outputShape.push_back(dim);
+    if (multiples[i] == ShapedType::kDynamic) {
+      outputShape.push_back(ShapedType::kDynamic);
+    } else {
+      int64_t dim = inputShape.getDimSize(i);
+      if (dim != ShapedType::kDynamic)
+        dim *= multiples[i];
+      outputShape.push_back(dim);
+    }
   }
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 9160f388be053..fe9da2ac09650 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -599,6 +599,17 @@ func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_tile_unknown_multiples
+func.func @test_tile_unknown_multiples(%arg0 : tensor<2x3x?xi32>) -> () {
+  // CHECK: %[[CST:.*]] = tosa.const_shape {values = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // CHECK: tosa.tile %arg0, %[[CST]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x?x?xi32>
+  %cst = tosa.const_shape {values = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<?x?x?xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_transpose_static
 func.func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
   // CHECK: tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>} : (tensor<3x4x5xi32>) -> tensor<5x4x3xi32>


        


More information about the Mlir-commits mailing list