[Mlir-commits] [mlir] [mlir][AMDGPU][NFC] Fix overlapping masked load refinements (PR #159805)

Krzysztof Drewniak llvmlistbot at llvm.org
Fri Sep 19 09:03:59 PDT 2025


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/159805

The two patterns for handling vector.maskedload on AMD GPUs had an overlap: both the "scalar mask becomes an if statement" pattern and the "masked loads become a normal load + a select on buffers" pattern could handle a load with a broadcast mask on a fat buffer resource.

This commit adds checks to resolve the overlap.
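For illustration, the kind of IR both patterns could previously claim is a vector.maskedload whose mask is a broadcast of a scalar i1 and whose base memref lives in the fat raw buffer address space. (Hand-written sketch; the function name, shapes, and types are illustrative and not taken from the patch or its tests.)

  func.func @overlapping_case(
      %mem: memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>,
      %idx: index, %pred: i1, %passthru: vector<4xf32>) -> vector<4xf32> {
    // Broadcast scalar mask: matched by the "if statement" pattern.
    %mask = vector.broadcast %pred : i1 to vector<4xi1>
    // Fat-buffer base: matched by the "load + select" pattern.
    %v = vector.maskedload %mem[%idx], %mask, %passthru
        : memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>,
          vector<4xi1>, vector<4xf32> into vector<4xf32>
    return %v : vector<4xf32>
  }

With this change, the buffer-specific pattern keeps such loads, and the conditional-load pattern bails out on anything in the fat buffer address space.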

From 6cb3912bee7c05bd3eceac65c2d2cf6e8a33369c Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Fri, 19 Sep 2025 16:00:32 +0000
Subject: [PATCH] [mlir][AMDGPU][NFC] Fix overlapping masked load refinements

The two patterns for handling vector.maskedload on AMD GPUs had an
overlap: both the "scalar mask becomes an if statement" pattern and
the "masked loads become a normal load + a select on buffers" pattern
could handle a load with a broadcast mask on a fat buffer resource.

This commit adds checks to resolve the overlap.
---
 .../AMDGPU/Transforms/MaskedloadToLoad.cpp    | 34 +++++++++++++------
 1 file changed, 24 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c166e0a..89ef51f922cad 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
 
 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();
 
   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();
 
   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();
 
   return success();
 }
@@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
   LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
                                 PatternRewriter &rewriter) const override {
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
-      return failure();
+      return rewriter.notifyMatchFailure(maskedOp, "already rewritten");
 
-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
-      return failure();
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
+      return rewriter.notifyMatchFailure(
+          maskedOp, "isn't a load from a fat buffer resource");
     }
 
     // Check if this is either a full inbounds load or an empty, oob load. If
@@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad
 
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
-      return failure();
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "isn't loading a broadcasted scalar");
     }
 
     Value cond = maybeCond.value();
@@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore
 
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
+    // A condition-free implementation of fully masked stores requires
+    // 1) an accessor for the num_records field on buffer resources/fat pointers
+    // 2) knowledge that said field will always be set accurately - that is,
+    // that writes at offsets x < num_records wouldn't trap, which is
+    // something a pattern user would need to assert or we'd need to prove.
+    //
+    // Therefore, fully masked stores to buffers still go down this
+    // conditional-store path at present.
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();

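For context on the new comment in FullMaskedStoreToConditionalStore: the rewrite that pattern performs is, roughly, to guard the whole store with the broadcast scalar condition rather than using a buffer-specific lowering. (Sketch inferred from the pattern name and the surrounding code; names and types are illustrative, not taken from the patch or its tests.)

  // Input: a fully masked store, i.e. the mask is a broadcast scalar.
  %mask = vector.broadcast %pred : i1 to vector<4xi1>
  vector.maskedstore %mem[%idx], %mask, %val
      : memref<8xf32>, vector<4xi1>, vector<4xf32>

  // After the rewrite: the store only happens when the scalar is true.
  scf.if %pred {
    vector.store %val, %mem[%idx] : memref<8xf32>, vector<4xf32>
  }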

