[Mlir-commits] [mlir] 751c3f5 - [mlir][tosa] Update TileOp infer shape (#134732)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 9 09:29:33 PDT 2025
Author: Jerry-Ge
Date: 2025-04-09T09:29:29-07:00
New Revision: 751c3f51eb65b493bf2fc2c6b2788f89e16a1fbe
URL: https://github.com/llvm/llvm-project/commit/751c3f51eb65b493bf2fc2c6b2788f89e16a1fbe
DIFF: https://github.com/llvm/llvm-project/commit/751c3f51eb65b493bf2fc2c6b2788f89e16a1fbe.diff
LOG: [mlir][tosa] Update TileOp infer shape (#134732)
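Update TileOp's shape inference to use getConstShapeValues. When the multiples are not constant, the inferred output shape is fully dynamic (with the rank of the multiples shape) instead of inference failing, and a dynamic (-1) multiple yields a dynamic output dimension.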
Signed-off-by: Tai Ly <tai.ly at arm.com>
Co-authored-by: Tai Ly <tai.ly at arm.com>
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 59946ca54b933..5941be8403480 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1616,19 +1616,25 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- DenseIntElementsAttr multiplesAttr;
- if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
- return failure();
-
- SmallVector<int64_t> multiples = llvm::to_vector(
- llvm::map_range(multiplesAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
+ Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
+ SmallVector<int64_t> multiples;
+ if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
+ multiples)) {
+ auto rank =
+ cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
+ SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+ inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+ return success();
+ } else {
+ multiples = convertToMlirShape(multiples);
+ }
ShapeAdaptor inputShape(adaptor.getInput1().getType());
SmallVector<int64_t> outputShape;
if (!inputShape.hasRank()) {
outputShape.resize(multiples.size(), ShapedType::kDynamic);
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(outputShape, inputType));
return success();
} else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
return failure();
@@ -1636,13 +1642,17 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
// Any non dynamic dimension can be multiplied to a known size.
outputShape.reserve(multiples.size());
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
- int64_t dim = inputShape.getDimSize(i);
- if (dim != ShapedType::kDynamic)
- dim *= multiples[i];
- outputShape.push_back(dim);
+ if (multiples[i] == ShapedType::kDynamic) {
+ outputShape.push_back(ShapedType::kDynamic);
+ } else {
+ int64_t dim = inputShape.getDimSize(i);
+ if (dim != ShapedType::kDynamic)
+ dim *= multiples[i];
+ outputShape.push_back(dim);
+ }
}
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 9160f388be053..fe9da2ac09650 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -599,6 +599,17 @@ func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
// -----
+// CHECK-LABEL: @test_tile_unknown_multiples
+func.func @test_tile_unknown_multiples(%arg0 : tensor<2x3x?xi32>) -> () {
+ // CHECK: %[[CST:.*]] = tosa.const_shape {values = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: tosa.tile %arg0, %[[CST]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x?x?xi32>
+ %cst = tosa.const_shape {values = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<?x?x?xi32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_transpose_static
func.func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
// CHECK: tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>} : (tensor<3x4x5xi32>) -> tensor<5x4x3xi32>
More information about the Mlir-commits mailing list