[Mlir-commits] [mlir] [mlir][AMDGPU][NFC] Fix overlapping masked load refinements (PR #159805)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 19 09:04:29 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-amdgpu

Author: Krzysztof Drewniak (krzysz00)

<details>
<summary>Changes</summary>

The two patterns for handling vector.maskedload on AMD GPUs had an overlap: both the "scalar mask becomes an if statement" pattern and the "masked loads become a normal load + a select on buffers" pattern could handle a load with a broadcast mask on a fat buffer resource.

This commit adds checks to resolve the overlap.
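
For illustration, a minimal hypothetical example of the overlapping case (function and value names are invented and the shapes are arbitrary): the mask is a broadcast of a single i1 and the base memref lives in the fat raw buffer address space, so before this change both patterns could match the same op.

```mlir
// Hypothetical overlapping case: a maskedload whose mask is a broadcast of a
// scalar i1 and whose base is a fat raw buffer memref.
func.func @broadcast_mask_on_buffer(
    %base: memref<8xf16, #amdgpu.address_space<fat_raw_buffer>>,
    %idx: index, %pred: i1, %passthru: vector<4xf16>) -> vector<4xf16> {
  %mask = vector.broadcast %pred : i1 to vector<4xi1>
  %v = vector.maskedload %base[%idx], %mask, %passthru
      : memref<8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16>
        into vector<4xf16>
  return %v : vector<4xf16>
}
```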

---
Full diff: https://github.com/llvm/llvm-project/pull/159805.diff


1 File Affected:

- (modified) mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp (+24-10) 


``````````diff
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c166e0a..89ef51f922cad 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
 
 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();
 
   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();
 
   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();
 
   return success();
 }
@@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
   LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
                                 PatternRewriter &rewriter) const override {
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
-      return failure();
+      return rewriter.notifyMatchFailure(maskedOp, "already rewritten");
 
-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
-      return failure();
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
+      return rewriter.notifyMatchFailure(
+          maskedOp, "isn't a load from a fat buffer resource");
     }
 
     // Check if this is either a full inbounds load or an empty, oob load. If
@@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad
 
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
-      return failure();
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "isn't loading a broadcasted scalar");
     }
 
     Value cond = maybeCond.value();
@@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore
 
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
+    // A condition-free implementation of fully masked stores requires
+    // 1) an accessor for the num_records field on buffer resources/fat pointers
+    // 2) knowledge that said field will always be set accurately - that is,
+    // that writes to x < num_records of offset wouldn't trap, which is
+    // something a pattern user would need to assert or we'd need to prove.
+    //
+    // Therefore, conditional stores to buffers still go down this path at
+    // present.
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();

``````````
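
For context, continuing the hypothetical example above, a rough sketch of the two rewrites that used to compete for such an op (simplified: the real buffer pattern derives the select condition from the buffer bounds rather than reusing %pred, and handles the out-of-bounds case separately):

```mlir
// Sketch: the buffer-specialized "normal load + select" rewrite.
%loaded = vector.load %base[%idx]
    : memref<8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
%v = arith.select %pred, %loaded, %passthru : vector<4xf16>

// Sketch: the general "scalar mask becomes an if statement" rewrite.
%w = scf.if %pred -> (vector<4xf16>) {
  %l = vector.load %base[%idx]
      : memref<8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
  scf.yield %l : vector<4xf16>
} else {
  scf.yield %passthru : vector<4xf16>
}
```

With this change, loads from fat buffer resources are handled only by the first, more specialized rewrite, and the scalar-mask pattern bails out on them.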

</details>


https://github.com/llvm/llvm-project/pull/159805


More information about the Mlir-commits mailing list