[Mlir-commits] [mlir] [mlir][linalg] Fix vectorizer generating invalid vector.gather for 0-D tensor.extract (PR #187085)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 17 12:00:27 PDT 2026


https://github.com/edg-l updated https://github.com/llvm/llvm-project/pull/187085

>From a18ee3754b3459967d75465998ebb32644fce3d6 Mon Sep 17 00:00:00 2001
From: Edgar Luque <git at edgl.dev>
Date: Tue, 17 Mar 2026 19:24:46 +0100
Subject: [PATCH] [mlir][linalg] Fix vectorizer generating invalid
 vector.gather for 0-D tensor.extract

When vectorizing a rank-0 linalg.generic whose body contains a
tensor.extract with data-dependent indices, the vectorizer incorrectly
classified the access as a Gather, because the 0-D result vector has no
dimension > 1. This produced an invalid vector.gather with a scalar
index operand where a vector of indices is required.

Fix this by classifying 0-D result vectors as ScalarBroadcast in
getTensorExtractMemoryAccessPattern, and by skipping the masking logic
in the ScalarBroadcast path when the result rank is 0 (0-D vectors
don't support masking).
---
 .../Linalg/Transforms/Vectorization.cpp       | 28 +++++++-----
 .../Dialect/Linalg/vectorization/extract.mlir | 44 +++++++++++++++++++
 2 files changed, 61 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0477815f329bf..d90d28fbddeec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1093,6 +1093,11 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   if (inputShape.getShape().empty())
     return VectorMemoryAccessKind::ScalarBroadcast;
 
+  // 0a. Is the result a 0-D vector? If yes, there are no iteration dimensions
+  // so the tensor.extract is a single scalar load regardless of the index.
+  if (resType.getRank() == 0)
+    return VectorMemoryAccessKind::ScalarBroadcast;
+
   // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
   // otherwise.
   bool isOutput1DVector =
@@ -1254,19 +1259,20 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
         rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
         /*padding=*/std::nullopt, permutationMap, inBounds);
 
-    // Mask this broadcasting xfer_read here rather than relying on the generic
-    // path (the generic path assumes identity masking map, which wouldn't be
-    // valid here).
-    SmallVector<int64_t> readMaskShape = {1};
-    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
-    auto allTrue = vector::ConstantMaskOp::create(
-        rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
-    auto *maskedReadOp =
-        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
+    Operation *resultOp = transferReadOp;
+    if (dstRank > 0) {
+      // Mask this broadcasting xfer_read here rather than relying on the
+      // generic path (the generic path assumes identity masking map, which
+      // wouldn't be valid here).
+      SmallVector<int64_t> readMaskShape = {1};
+      auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+      auto allTrue = vector::ConstantMaskOp::create(
+          rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
+      resultOp = mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
+    }
 
     LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
-    return VectorizationHookResult{VectorizationHookStatus::NewOp,
-                                   maskedReadOp};
+    return VectorizationHookResult{VectorizationHookStatus::NewOp, resultOp};
   }
 
   // 2b. Handle contiguous access.
diff --git a/mlir/test/Dialect/Linalg/vectorization/extract.mlir b/mlir/test/Dialect/Linalg/vectorization/extract.mlir
index 76ac4b8398069..88ee03167aa35 100644
--- a/mlir/test/Dialect/Linalg/vectorization/extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/extract.mlir
@@ -477,3 +477,47 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Rank-0 linalg.generic with tensor.extract using a data-dependent index.
+// The tensor.extract should be classified as ScalarBroadcast (not Gather),
+// producing a vector.transfer_read of a 0-D vector.
+
+func.func @rank0_tensor_extract_data_dependent_index(
+    %src: tensor<2xi64>,
+    %idx_tensor: tensor<i64>) -> tensor<i64> {
+
+  %init = tensor.empty() : tensor<i64>
+  %res = linalg.generic {
+    indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+    iterator_types = []}
+    ins(%idx_tensor : tensor<i64>) outs(%init : tensor<i64>) {
+  ^bb0(%in: i64, %out: i64):
+    %idx = arith.index_cast %in : i64 to index
+    %val = tensor.extract %src[%idx] : tensor<2xi64>
+    linalg.yield %val : i64
+  } -> tensor<i64>
+
+  return %res : tensor<i64>
+}
+
+// CHECK-LABEL: func.func @rank0_tensor_extract_data_dependent_index(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<2xi64>,
+// CHECK-SAME:      %[[IDX_TENSOR:.*]]: tensor<i64>) -> tensor<i64> {
+// CHECK-DAG:       %[[INIT:.*]] = tensor.empty() : tensor<i64>
+// CHECK-DAG:       %[[PAD:.*]] = ub.poison : i64
+// CHECK:           %[[READ_IDX:.*]] = vector.transfer_read %[[IDX_TENSOR]][], %[[PAD]] : tensor<i64>, vector<i64>
+// CHECK:           %[[SCALAR_IDX:.*]] = vector.extract %[[READ_IDX]][] : i64 from vector<i64>
+// CHECK:           %[[INDEX:.*]] = arith.index_cast %[[SCALAR_IDX]] : i64 to index
+// CHECK:           %[[READ_VAL:.*]] = vector.transfer_read %[[SRC]][%[[INDEX]]], %{{.*}} : tensor<2xi64>, vector<i64>
+// CHECK:           %[[WRITE:.*]] = vector.transfer_write %[[READ_VAL]], %[[INIT]][] : vector<i64>, tensor<i64>
+// CHECK:           return %[[WRITE]] : tensor<i64>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %module : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 {vectorize_nd_extract} : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list