[Mlir-commits] [mlir] [mlir][linalg] Handle reassociationIndices correctly for 0D tensor (PR #121683)
Longsheng Mou
llvmlistbot at llvm.org
Sun Jan 5 01:03:46 PST 2025
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/121683
From 6fa17acb0febf0977522b7570379c3bcc1035ac4 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sun, 5 Jan 2025 16:29:48 +0800
Subject: [PATCH] [mlir][linalg] Handle reassociationIndices correctly for 0D
tensor
This PR fixes a crash in expandRank() caused by writing into
reassociationIndices when it has zero entries, which happens when the
input is a 0-D tensor.
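For illustration only, here is a minimal standalone sketch of the indexing
pattern in expandRank(), with std::vector standing in for llvm::SmallVector
(variable names and values are illustrative, not the actual MLIR code):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    int main() {
      int64_t rank = 0;         // a 0-D input tensor
      int64_t numExtraDims = 2; // unit dims to prepend when broadcasting
      // One reassociation group per input dimension; empty for rank 0.
      std::vector<std::vector<int64_t>> reassociationIndices(rank);

      if (rank != 0) { // guard added by this patch
        int64_t index = 0;
        // Input dim 0 expands into the new leading unit dims plus itself ...
        for (index = 0; index <= numExtraDims; index++)
          reassociationIndices[0].push_back(index);
        // ... and each remaining input dim maps to a single expanded dim.
        for (size_t position = 1; position < reassociationIndices.size();
             position++)
          reassociationIndices[position].push_back(index++);
      }
      // Without the guard, reassociationIndices[0] above is an out-of-bounds
      // access for rank 0; with it, a 0-D tensor gets an empty reassociation,
      // matching "tensor.expand_shape ... []" in the new test.
      return 0;
    }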
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 11 +++++----
.../TosaToLinalg/tosa-to-linalg.mlir | 23 +++++++++++++++++++
2 files changed, 30 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 88e544c4e4b5f1..ac4078a9ffe0cb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -611,10 +611,13 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
shapedType.getRank());
int64_t index = 0;
- for (index = 0; index <= numExtraDims; index++)
- reassociationIndices[0].push_back(index);
- for (size_t position = 1; position < reassociationIndices.size(); position++)
- reassociationIndices[position].push_back(index++);
+ if (shapedType.getRank() != 0) {
+ for (index = 0; index <= numExtraDims; index++)
+ reassociationIndices[0].push_back(index);
+ for (size_t position = 1; position < reassociationIndices.size();
+ position++)
+ reassociationIndices[position].push_back(index++);
+ }
// Compute result type
SmallVector<int64_t> resultShape;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 265a75986c6c8d..50f4a84bcc89a0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1964,3 +1964,26 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
%0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
return %0: tensor<1xi64>
}
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, 0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @test_add_0d_broadcast(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
+// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
+// CHECK: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<2x1xf32>
+// CHECK: return %[[RESULT]] : tensor<2x1xf32>
+// CHECK: }
+func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
+ %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
+ return %0 : tensor<2x1xf32>
+}