[Mlir-commits] [mlir] f964922 - [mlir][AMDGPU] Add better load/store lowering for full mask (#146748)
llvmlistbot at llvm.org
Thu Jul 10 08:11:24 PDT 2025
Author: Kunwar Grover
Date: 2025-07-10T16:11:19+01:00
New Revision: f96492221d2bf272053bca6660fc4bdd86592478
URL: https://github.com/llvm/llvm-project/commit/f96492221d2bf272053bca6660fc4bdd86592478
DIFF: https://github.com/llvm/llvm-project/commit/f96492221d2bf272053bca6660fc4bdd86592478.diff
LOG: [mlir][AMDGPU] Add better load/store lowering for full mask (#146748)
This patch adds a better maskedload/maskedstore lowering on the AMDGPU
backend for loads whose mask is either fully set or fully unset (i.e. a
broadcast of a single i1 condition). For these cases, we can either
generate an oob buffer load with no if condition (when the base is in the
fat_raw_buffer address space), or a normal load guarded by an if condition
otherwise.
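For illustration, a rough sketch of the non-buffer case handled by the new
FullMaskedLoadToConditionalLoad pattern (the value names here are made up
and not taken from the patch). A masked load whose mask is a broadcast of a
single i1 condition, such as

  %mask = vector.broadcast %cond : i1 to vector<4xi1>
  %res = vector.maskedload %mem[%i, %i], %mask, %passthru
      : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>

is rewritten into a plain vector.load guarded by an scf.if on the scalar
condition, yielding the passthru value on the else path:

  %res = scf.if %cond -> (vector<4xf16>) {
    // Mask is all-true at runtime: do an ordinary vector load.
    %load = vector.load %mem[%i, %i] : memref<8x8xf16>, vector<4xf16>
    scf.yield %load : vector<4xf16>
  } else {
    // Mask is all-false at runtime: the result is just the passthru.
    scf.yield %passthru : vector<4xf16>
  }

In the fat_raw_buffer case, the test expectations below show the load being
emitted unconditionally (the oob access is safe on buffer loads), followed
by an arith.select against the passthru.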
Added:
Modified:
mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index 9a368f372c296..60c8660658a95 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
@@ -52,13 +53,25 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
}
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
- vector::MaskedLoadOp maskedOp) {
+ vector::MaskedLoadOp maskedOp,
+ bool passthru) {
VectorType vectorType = maskedOp.getVectorType();
Value load = builder.create<vector::LoadOp>(
loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
- Value res = builder.create<arith::SelectOp>(
- loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru());
- return res;
+ if (passthru)
+ load = builder.create<arith::SelectOp>(loc, vectorType, maskedOp.getMask(),
+ load, maskedOp.getPassThru());
+ return load;
+}
+
+/// Check if the given value comes from a broadcasted i1 condition.
+static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+ auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
+ if (!broadcastOp)
+ return failure();
+ if (isa<VectorType>(broadcastOp.getSourceType()))
+ return failure();
+ return broadcastOp.getSource();
}
static constexpr char kMaskedloadNeedsMask[] =
@@ -78,6 +91,16 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
return failure();
}
+ // Check if this is either a full inbounds load or an empty, oob load. If
+ // so, take the fast path and don't generate an if condition, because we
+ // know doing the oob load is always safe.
+ if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
+ Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
+ maskedOp, /*passthru=*/true);
+ rewriter.replaceOp(maskedOp, load);
+ return success();
+ }
+
Location loc = maskedOp.getLoc();
Value src = maskedOp.getBase();
@@ -135,7 +158,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
};
auto elseBuilder = [&](OpBuilder &builder, Location loc) {
- Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp);
+ Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
+ /*passthru=*/true);
rewriter.create<scf::YieldOp>(loc, res);
};
@@ -148,11 +172,63 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
}
};
+struct FullMaskedLoadToConditionalLoad
+ : OpRewritePattern<vector::MaskedLoadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
+ if (failed(maybeCond)) {
+ return failure();
+ }
+
+ Value cond = maybeCond.value();
+ auto trueBuilder = [&](OpBuilder &builder, Location loc) {
+ Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
+ /*passthru=*/false);
+ rewriter.create<scf::YieldOp>(loc, res);
+ };
+ auto falseBuilder = [&](OpBuilder &builder, Location loc) {
+ rewriter.create<scf::YieldOp>(loc, loadOp.getPassThru());
+ };
+ auto ifOp = rewriter.create<scf::IfOp>(loadOp.getLoc(), cond, trueBuilder,
+ falseBuilder);
+ rewriter.replaceOp(loadOp, ifOp);
+ return success();
+ }
+};
+
+struct FullMaskedStoreToConditionalStore
+ : OpRewritePattern<vector::MaskedStoreOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
+ if (failed(maybeCond)) {
+ return failure();
+ }
+ Value cond = maybeCond.value();
+
+ auto trueBuilder = [&](OpBuilder &builder, Location loc) {
+ rewriter.create<vector::StoreOp>(loc, storeOp.getValueToStore(),
+ storeOp.getBase(), storeOp.getIndices());
+ rewriter.create<scf::YieldOp>(loc);
+ };
+ auto ifOp = rewriter.create<scf::IfOp>(storeOp.getLoc(), cond, trueBuilder);
+ rewriter.replaceOp(storeOp, ifOp);
+ return success();
+ }
+};
+
} // namespace
void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<MaskedLoadLowering>(patterns.getContext(), benefit);
+ patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
+ FullMaskedStoreToConditionalStore>(patterns.getContext(),
+ benefit);
}
struct AmdgpuMaskedloadToLoadPass final
diff --git a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
index febe46bf7a759..f1d0ad545539a 100644
--- a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
@@ -114,3 +114,56 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
// CHECK: %[[IF:.*]] = scf.if
// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG1]]]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[ARG3]]
+
+// -----
+
+func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> {
+ %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_fatrawbuffer_to_load
+func.func @full_select_maskedload_fatrawbuffer_to_load(%arg0: memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+ %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+ %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+ return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: vector.load
+// CHECK: arith.select
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+ %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+ %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+ return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_mask_maskedstore_to_store
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+func.func @full_mask_maskedstore_to_store(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) {
+ %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+ vector.maskedstore %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16>
+ return
+}
+// CHECK-NOT: vector.maskedstore
+// CHECK: scf.if %[[PRED]]
+// CHECK: vector.store