[Mlir-commits] [mlir] [mlir][linalg] Handle reassociationIndices correctly for 0D tensor (PR #121683)

Longsheng Mou llvmlistbot at llvm.org
Tue Jan 7 06:29:45 PST 2025


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/121683

From 72ebc8086c7ef5b2570e5026fe6a97e3abb607d5 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sun, 5 Jan 2025 16:29:48 +0800
Subject: [PATCH] [mlir][linalg] Handle reassociationIndices correctly for 0D
 tensor

This PR fixes a crash caused by writing to an empty (0-sized)
reassociationIndices when the input is a 0D tensor.
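
With a rank-0 input, reassociationIndices is constructed with size 0, yet the
first loop still wrote to reassociationIndices[0], which is out of bounds. The
index loops are now skipped for rank-0 tensors, so tensor.expand_shape is
emitted with an empty reassociation list.

A minimal input that hits this path (the function name is illustrative; it is
the same shape combination exercised by the new test below):

  func.func @example(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
    %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
    return %0 : tensor<2x1xf32>
  }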
---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 25 +++++++++++--------
 .../TosaToLinalg/tosa-to-linalg.mlir          | 23 +++++++++++++++++
 2 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 88e544c4e4b5f1..1d7ead16e8b631 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -600,30 +600,33 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
                         int64_t rank) {
   // No need to expand if we are already at the desired rank
-  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
-  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
-  int64_t numExtraDims = rank - shapedType.getRank();
+  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
+  assert(tensorType && "expected a ranked tensor type");
+  int64_t tensorRank = tensorType.getRank();
+  int64_t numExtraDims = rank - tensorRank;
   assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
   if (!numExtraDims)
     return tensor;
 
   // Compute reassociation indices
-  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
-      shapedType.getRank());
+  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
   int64_t index = 0;
-  for (index = 0; index <= numExtraDims; index++)
-    reassociationIndices[0].push_back(index);
-  for (size_t position = 1; position < reassociationIndices.size(); position++)
-    reassociationIndices[position].push_back(index++);
+  if (tensorRank != 0) {
+    for (index = 0; index <= numExtraDims; index++)
+      reassociationIndices[0].push_back(index);
+    for (size_t position = 1; position < reassociationIndices.size();
+         position++)
+      reassociationIndices[position].push_back(index++);
+  }
 
   // Compute result type
   SmallVector<int64_t> resultShape;
   for (index = 0; index < numExtraDims; index++)
     resultShape.push_back(1);
-  for (auto size : shapedType.getShape())
+  for (auto size : tensorType.getShape())
     resultShape.push_back(size);
   auto resultType =
-      RankedTensorType::get(resultShape, shapedType.getElementType());
+      RankedTensorType::get(resultShape, tensorType.getElementType());
 
   // Emit 'tensor.expand_shape' op
   return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 265a75986c6c8d..c840fb8648d7b7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -100,6 +100,29 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
 
 // -----
 
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, 0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @test_add_0d_broadcast(
+// CHECK-SAME:                                     %[[ARG0:.*]]: tensor<2x1xf32>,
+// CHECK-SAME:                                     %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
+// CHECK:           %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+// CHECK:           %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
+// CHECK:           %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
+// CHECK:           ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK:             %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
+// CHECK:             linalg.yield %[[ADD]] : f32
+// CHECK:           } -> tensor<2x1xf32>
+// CHECK:           return %[[RESULT]] : tensor<2x1xf32>
+// CHECK:         }
+func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
+  %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
+  return %0 : tensor<2x1xf32>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: @test_add_1d_all_dynamic


