[Mlir-commits] [mlir] [mlir][AMDGPU][NFC] Fix overlapping masked load refinements (PR #159805)
Krzysztof Drewniak
llvmlistbot at llvm.org
Fri Sep 19 09:03:59 PDT 2025
https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/159805
The two patterns for handling vector.maskedload on AMD GPUs overlapped: both the "scalar mask becomes an if statement" pattern and the "masked loads become a normal load + a select on buffers" pattern could handle a load with a broadcast mask on a fat buffer resource.
This commit adds checks to resolve the overlap.
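For illustration, the case both patterns could previously match is a maskedload whose mask is a broadcast of a scalar condition and whose base memref lives in the fat raw buffer address space. A minimal sketch (function and value names invented for the example, not taken from the patch):

    func.func @overlap(%base: memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>,
                       %idx: index, %cond: i1, %passthru: vector<4xf32>) -> vector<4xf32> {
      // Broadcasting a scalar condition yields a "full" mask: all on or all off.
      %mask = vector.broadcast %cond : i1 to vector<4xi1>
      %res = vector.maskedload %base[%idx], %mask, %passthru
        : memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32>
          into vector<4xf32>
      return %res : vector<4xf32>
    }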
From 6cb3912bee7c05bd3eceac65c2d2cf6e8a33369c Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Fri, 19 Sep 2025 16:00:32 +0000
Subject: [PATCH] [mlir][AMDGPU][NFC] Fix overlapping masked load refinements
The two patterns for handling vector.maskedload on AMD GPUs
overlapped: both the "scalar mask becomes an if statement" pattern and
the "masked loads become a normal load + a select on buffers" pattern
could handle a load with a broadcast mask on a fat buffer resource.
This commit adds checks to resolve the overlap.
---
.../AMDGPU/Transforms/MaskedloadToLoad.cpp | 34 +++++++++++++------
1 file changed, 24 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c166e0a..89ef51f922cad 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
/// This pattern supports lowering of: `vector.maskedload` to `vector.load`
/// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
- vector::MaskedLoadOp maskedOp) {
- auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+ auto memRefType = dyn_cast<MemRefType>(type);
if (!memRefType)
- return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+ return failure();
Attribute addrSpace = memRefType.getMemorySpace();
if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
- return rewriter.notifyMatchFailure(maskedOp, "no address space");
+ return failure();
if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
amdgpu::AddressSpace::FatRawBuffer)
- return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+ return failure();
return success();
}
@@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
PatternRewriter &rewriter) const override {
if (maskedOp->hasAttr(kMaskedloadNeedsMask))
- return failure();
+ return rewriter.notifyMatchFailure(maskedOp, "already rewritten");
- if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
- return failure();
+ if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
+ return rewriter.notifyMatchFailure(
+ maskedOp, "isn't a load from a fat buffer resource");
}
// Check if this is either a full inbounds load or an empty, oob load. If
@@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad
LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
PatternRewriter &rewriter) const override {
+ if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+ return rewriter.notifyMatchFailure(
+ loadOp, "buffer loads are handled by a more specialized pattern");
+
FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
if (failed(maybeCond)) {
- return failure();
+ return rewriter.notifyMatchFailure(loadOp,
+ "isn't loading a broadcasted scalar");
}
Value cond = maybeCond.value();
@@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore
LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
PatternRewriter &rewriter) const override {
+ // A condition-free implementation of fully masked stores requires
+ // 1) an accessor for the num_records field on buffer resources/fat pointers
+ // 2) knowledge that said field will always be set accurately - that is,
+ //    that writes at offsets x < num_records won't trap, which is
+ // something a pattern user would need to assert or we'd need to prove.
+ //
+ // Therefore, masked stores to buffers still go down this path at
+ // present.
+
FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
if (failed(maybeCond)) {
return failure();
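To sketch the two rewrites named above (shapes assumed from the pattern descriptions, not copied from the patch): on a fat buffer resource, MaskedLoadLowering turns the masked load into a plain load plus a select on the mask,

    // Out-of-bounds lanes on a fat raw buffer read safely, so the load can be
    // unconditional and the mask applied afterwards.
    %loaded = vector.load %base[%idx]
      : memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
    %res = arith.select %mask, %loaded, %passthru : vector<4xi1>, vector<4xf32>

while FullMaskedLoadToConditionalLoad guards a plain load with the scalar condition:

    %res = scf.if %cond -> (vector<4xf32>) {
      %v = vector.load %base[%idx] : memref<8xf32>, vector<4xf32>
      scf.yield %v : vector<4xf32>
    } else {
      scf.yield %passthru : vector<4xf32>
    }

With this patch, a load that fits both descriptions takes the buffer path, since FullMaskedLoadToConditionalLoad now bails out on fat buffer sources.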