[Mlir-commits] [mlir] [mlir][vector] Fix masked load/store emulation for rank-0 memrefs (PR #173325)
Prathamesh Tagore
llvmlistbot at llvm.org
Mon Dec 22 17:38:57 PST 2025
https://github.com/meshtag created https://github.com/llvm/llvm-project/pull/173325
Added rank‑0 handling to masked load/store emulation by reinterpreting rank‑0 memrefs as 1‑D buffers with a synthetic index, preventing empty‑indices crashes.
Fixes https://github.com/llvm/llvm-project/issues/131243
From 163f2244fd1f99919b68c4d7fa4141b2375db43a Mon Sep 17 00:00:00 2001
From: Prathamesh Tagore <prathameshtagore at gmail.com>
Date: Tue, 23 Dec 2025 02:30:55 +0100
Subject: [PATCH] [mlir][vector] Fix masked load/store emulation for rank-0
memrefs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Added rank‑0 handling to masked load/store emulation by reinterpreting rank‑0
memrefs as 1‑D buffers with a synthetic index, preventing empty‑indices
crashes.
---
.../VectorEmulateMaskedLoadStore.cpp | 22 +++++++++++++++++++
.../vector-emulate-masked-load-store.mlir | 18 +++++++++++++++
2 files changed, 40 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index 7acc120508a44..cfd478b27908b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -19,6 +19,26 @@ using namespace mlir;
namespace {
+/// When `indices` is empty (rank-0 `base`), rebind `base` to a 1-D
+/// reinterpret_cast of `maskLength` unit-stride elements and append index 0.
+static void ensureBaseHasIndex(PatternRewriter &rewriter, Location loc,
+ Value &base, SmallVectorImpl<Value> &indices,
+ int64_t maskLength) {
+ if (!indices.empty())
+ return;
+
+ // No index to step per lane: recast as 1-D, keeping base buffer and offset.
+ auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, base);
+ SmallVector<OpFoldResult, 1> sizes = {rewriter.getIndexAttr(maskLength)};
+ SmallVector<OpFoldResult, 1> strides = {rewriter.getIndexAttr(1)};
+ base = memref::ReinterpretCastOp::create(rewriter, loc, meta.getBaseBuffer(),
+ meta.getOffset(), sizes, strides);
+
+ Type indexType = rewriter.getIndexType();
+ indices.push_back(arith::ConstantOp::create(rewriter, loc, indexType,
+ IntegerAttr::get(indexType, 0)));
+}
+
/// Convert vector.maskedload
///
/// Before:
@@ -65,6 +85,7 @@ struct VectorMaskedLoadOpConverter final
Value base = maskedLoadOp.getBase();
Value iValue = maskedLoadOp.getPassThru();
auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
+ ensureBaseHasIndex(rewriter, loc, base, indices, maskLength);
Value one = arith::ConstantOp::create(rewriter, loc, indexType,
IntegerAttr::get(indexType, 1));
for (int64_t i = 0; i < maskLength; ++i) {
@@ -135,6 +156,7 @@ struct VectorMaskedStoreOpConverter final
Value value = maskedStoreOp.getValueToStore();
bool nontemporal = false;
auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
+ ensureBaseHasIndex(rewriter, loc, base, indices, maskLength);
Value one = arith::ConstantOp::create(rewriter, loc, indexType,
IntegerAttr::get(indexType, 1));
for (int64_t i = 0; i < maskLength; ++i) {
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
index 6e5d68c859e2c..8bace2ca9875b 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -123,3 +123,21 @@ func.func @vector_maskedstore_with_alignment(%arg0 : memref<4x5xf32>, %arg1 : ve
vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 { alignment = 8 } : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
return
}
+
+// CHECK-LABEL: @vector_maskedstore_rank0
+// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xi1>, %[[ARG4:.*]]: vector<1xf32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview %[[ARG0]]{{\[}}%[[ARG1]], %[[ARG2]]] [1, 1] [1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[SUBVIEW]] : memref<f32, strided<[], offset: ?>> -> memref<f32>, index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [1], strides: [1] : memref<f32> to memref<1xf32, strided<[1], offset: ?>>
+// CHECK-NEXT: %[[M0:.*]] = vector.extract %[[ARG3]][0] : i1 from vector<1xi1>
+// CHECK-NEXT: scf.if %[[M0]] {
+// CHECK-NEXT: %[[V0:.*]] = vector.extract %[[ARG4]][0] : f32 from vector<1xf32>
+// CHECK-NEXT: memref.store %[[V0]], %[[REINT]][%[[C0]]] : memref<1xf32, strided<[1], offset: ?>>
+func.func @vector_maskedstore_rank0(%arg0: memref<12x32xf32>, %arg1: index,
+ %arg2: index, %arg3: vector<1xi1>,
+ %arg4: vector<1xf32>) {
+ %subview = memref.subview %arg0[%arg1, %arg2] [1, 1] [1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+ vector.maskedstore %subview[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<1xi1>, vector<1xf32>
+ return
+}
More information about the Mlir-commits
mailing list