[Mlir-commits] [mlir] [mlir][vector] Fix masked load/store emulation for rank-0 memrefs (PR #173325)

Prathamesh Tagore llvmlistbot at llvm.org
Mon Dec 22 17:38:57 PST 2025


https://github.com/meshtag created https://github.com/llvm/llvm-project/pull/173325

Added rank-0 handling to masked load/store emulation by reinterpreting rank-0 memrefs as 1-D buffers with a synthetic index, preventing crashes on empty index lists.

Fixes https://github.com/llvm/llvm-project/issues/131243

>From 163f2244fd1f99919b68c4d7fa4141b2375db43a Mon Sep 17 00:00:00 2001
From: Prathamesh Tagore <prathameshtagore at gmail.com>
Date: Tue, 23 Dec 2025 02:30:55 +0100
Subject: [PATCH] [mlir][vector] Fix masked load/store emulation for rank-0
 memrefs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Added rank-0 handling to masked load/store emulation by reinterpreting rank-0
memrefs as 1-D buffers with a synthetic index, preventing crashes on empty
index lists.
---
 .../VectorEmulateMaskedLoadStore.cpp          | 22 +++++++++++++++++++
 .../vector-emulate-masked-load-store.mlir     | 18 +++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index 7acc120508a44..cfd478b27908b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -19,6 +19,26 @@ using namespace mlir;
 
 namespace {
 
+/// Rank-0 memrefs give the per-lane emulation no index to step. If `indices`
+/// is empty, rebind `base` to a 1-D view and append a constant 0 index.
+static void ensureBaseHasIndex(PatternRewriter &rewriter, Location loc,
+                               Value &base, SmallVectorImpl<Value> &indices,
+                               int64_t maskLength) {
+  if (!indices.empty())
+    return;
+
+  // View the same base buffer and offset as `maskLength` elements, stride 1.
+  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, base);
+  SmallVector<OpFoldResult, 1> sizes = {rewriter.getIndexAttr(maskLength)};
+  SmallVector<OpFoldResult, 1> strides = {rewriter.getIndexAttr(1)};
+  base = memref::ReinterpretCastOp::create(rewriter, loc, meta.getBaseBuffer(),
+                                           meta.getOffset(), sizes, strides);
+
+  Type indexType = rewriter.getIndexType();
+  indices.push_back(arith::ConstantOp::create(rewriter, loc, indexType,
+                                              IntegerAttr::get(indexType, 0)));
+}
+
 /// Convert vector.maskedload
 ///
 /// Before:
@@ -65,6 +85,7 @@ struct VectorMaskedLoadOpConverter final
     Value base = maskedLoadOp.getBase();
     Value iValue = maskedLoadOp.getPassThru();
     auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
+    ensureBaseHasIndex(rewriter, loc, base, indices, maskLength);
     Value one = arith::ConstantOp::create(rewriter, loc, indexType,
                                           IntegerAttr::get(indexType, 1));
     for (int64_t i = 0; i < maskLength; ++i) {
@@ -135,6 +156,7 @@ struct VectorMaskedStoreOpConverter final
     Value value = maskedStoreOp.getValueToStore();
     bool nontemporal = false;
     auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
+    ensureBaseHasIndex(rewriter, loc, base, indices, maskLength);
     Value one = arith::ConstantOp::create(rewriter, loc, indexType,
                                           IntegerAttr::get(indexType, 1));
     for (int64_t i = 0; i < maskLength; ++i) {
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
index 6e5d68c859e2c..8bace2ca9875b 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -123,3 +123,21 @@ func.func @vector_maskedstore_with_alignment(%arg0 : memref<4x5xf32>, %arg1 : ve
   vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 { alignment = 8 } : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
   return
 }
+// Rank-0 base memref: emulation reinterprets it as 1-D before storing.
+// CHECK-LABEL:  @vector_maskedstore_rank0
+// CHECK-SAME:   (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xi1>, %[[ARG4:.*]]: vector<1xf32>) {
+// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK-NEXT:   %[[SUBVIEW:.*]] = memref.subview %[[ARG0]]{{\[}}%[[ARG1]], %[[ARG2]]] [1, 1] [1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+// CHECK-NEXT:   %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[SUBVIEW]] : memref<f32, strided<[], offset: ?>> -> memref<f32>, index
+// CHECK-NEXT:   %[[REINT:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [1], strides: [1] : memref<f32> to memref<1xf32, strided<[1], offset: ?>>
+// CHECK-NEXT:   %[[M0:.*]] = vector.extract %[[ARG3]][0] : i1 from vector<1xi1>
+// CHECK-NEXT:   scf.if %[[M0]] {
+// CHECK-NEXT:     %[[V0:.*]] = vector.extract %[[ARG4]][0] : f32 from vector<1xf32>
+// CHECK-NEXT:     memref.store %[[V0]], %[[REINT]][%[[C0]]] : memref<1xf32, strided<[1], offset: ?>>
+func.func @vector_maskedstore_rank0(%arg0: memref<12x32xf32>, %arg1: index,
+                                   %arg2: index, %arg3: vector<1xi1>,
+                                   %arg4: vector<1xf32>) {
+  %subview = memref.subview %arg0[%arg1, %arg2] [1, 1] [1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+  vector.maskedstore %subview[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<1xi1>, vector<1xf32>
+  return
+}



More information about the Mlir-commits mailing list