[Mlir-commits] [mlir] [mlir][vector] Missing indices on vectorization of 1-d reduction to 1-ranked memref (PR #166959)

Simone Pellegrini llvmlistbot at llvm.org
Fri Nov 7 07:42:10 PST 2025


https://github.com/simpel01 created https://github.com/llvm/llvm-project/pull/166959


Vectorizing a 1-D reduction whose output operand is a rank-1 memref can generate an invalid `vector.transfer_write` that has no indices for the memref, e.g.:

  "vector.transfer_write"(%vec, %buff) <{...}> : (vector<f32>, memref<1xf32>) -> ()

This patch fixes the problem by passing the expected number of indices: one constant-zero index per memref dimension, i.e. as many indices as the memref's rank.

>From b67ae6a89d12b60cf5bc8198bb28fbb910281e43 Mon Sep 17 00:00:00 2001
From: Simone Pellegrini <simone.pellegrini at arm.com>
Date: Fri, 7 Nov 2025 13:48:48 +0100
Subject: [PATCH] [mlir][vector] Missing indices on vectorization of 1-d
 reduction to 1-ranked memref

Vectorizing a 1-D reduction whose output operand is a rank-1 memref can
generate an invalid `vector.transfer_write` that has no indices for the memref, e.g.:

  "vector.transfer_write"(%vec, %buff) <{...}> : (vector<f32>, memref<1xf32>) -> ()

This patch fixes the problem by passing the expected number of indices (one
constant-zero index per memref dimension, i.e. matching the memref's rank).
---
 .../Linalg/Transforms/Vectorization.cpp       |  8 ++--
 .../linalg-ops-with-patterns.mlir             | 43 +++++++++++++++++++
 2 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 19d2d854a3838..4eb2a0cb200a0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -746,12 +746,12 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
   auto vectorType = state.getCanonicalVecType(
       getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
 
+  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
+                             arith::ConstantIndexOp::create(rewriter, loc, 0));
+
   Operation *write;
   if (vectorType.getRank() > 0) {
     AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
-    SmallVector<Value> indices(
-        linalgOp.getRank(outputOperand),
-        arith::ConstantIndexOp::create(rewriter, loc, 0));
     value = broadcastIfNeeded(rewriter, value, vectorType);
     assert(value.getType() == vectorType && "Incorrect type");
     write = vector::TransferWriteOp::create(
@@ -762,7 +762,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
       value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
     assert(value.getType() == vectorType && "Incorrect type");
     write = vector::TransferWriteOp::create(rewriter, loc, value,
-                                            outputOperand->get(), ValueRange{});
+                                            outputOperand->get(), indices);
   }
 
   write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 9a14ab7d38d3e..6e63dfe7bb8e5 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -1523,6 +1523,49 @@ module attributes {transform.with_named_sequence} {
 }
 
 
+// -----
+
+//  CHECK-LABEL: func @reduce_1d_memref(
+//   CHECK-SAME:   %[[A:.*]]: memref<32xf32>
+//   CHECK-SAME:   %[[B:.*]]: memref<1xf32>
+func.func @reduce_1d_memref(%arg0: memref<32xf32>, %arg1: memref<1xf32>) {
+  //  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+  //      CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
+  // CHECK-SAME:   : memref<32xf32>, vector<32xf32>
+  //      CHECK: %[[init:.*]] = vector.transfer_read %[[B]][%[[C0]]]
+  // CHECK-SAME:   : memref<1xf32>, vector<f32>
+  //      CHECK: %[[init_scl:.*]] = vector.extract %[[init]][]
+  // CHECK-SAME:   : f32 from vector<f32>
+  //      CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[init_scl]] [0]
+  // CHECK-SAME:   : vector<32xf32> to f32
+  //      CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
+  //      CHECK: vector.transfer_write %[[red_v1]], %[[B]][%[[C0]]]
+  // CHECK-SAME:   : vector<f32>, memref<1xf32>
+  linalg.generic {
+         indexing_maps = [affine_map<(d0) -> (d0)>,
+                          affine_map<(d0) -> (0)>],
+         iterator_types = ["reduction"]}
+         ins(%arg0 : memref<32xf32>)
+         outs(%arg1 : memref<1xf32>) {
+    ^bb0(%a: f32, %b: f32):
+      %0 = arith.addf %a, %b : f32
+      linalg.yield %0 : f32
+    }
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+
 // -----
 
 // This test checks that vectorization does not occur when an input indexing map



More information about the Mlir-commits mailing list