[Mlir-commits] [mlir] f964922 - [mlir][AMDGPU] Add better load/store lowering for full mask (#146748)

llvmlistbot at llvm.org
Thu Jul 10 08:11:24 PDT 2025


Author: Kunwar Grover
Date: 2025-07-10T16:11:19+01:00
New Revision: f96492221d2bf272053bca6660fc4bdd86592478

URL: https://github.com/llvm/llvm-project/commit/f96492221d2bf272053bca6660fc4bdd86592478
DIFF: https://github.com/llvm/llvm-project/commit/f96492221d2bf272053bca6660fc4bdd86592478.diff

LOG: [mlir][AMDGPU] Add better load/store lowering for full mask (#146748)

This patch adds a better maskedload/maskedstore lowering in the AMDGPU
backend for loads that are either fully masked or fully unmasked. For
these cases, we can either generate an out-of-bounds (oob) buffer load
with no if condition, or a normal load guarded by an if condition (when
there is no fat_raw_buffer address space).
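
A minimal, illustrative sketch of the rewrites (value names are invented
for the example; the authoritative IR is in the updated test file below):

  // Input: a maskedload whose mask is the broadcast of a single i1 %pred.
  %mask = vector.broadcast %pred : i1 to vector<4xi1>
  %r0 = vector.maskedload %mem[%i, %i], %mask, %passthru
      : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>

  // If the base is in #amdgpu.address_space<fat_raw_buffer> (written %buf
  // here), the oob load is always safe, so no branch is emitted:
  %load = vector.load %buf[%i, %i]
      : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
  %r1 = arith.select %mask, %load, %passthru : vector<4xi1>, vector<4xf16>

  // Otherwise, the load is guarded by the scalar condition, with the
  // passthru yielded on the false branch:
  %r2 = scf.if %pred -> (vector<4xf16>) {
    %l = vector.load %mem[%i, %i] : memref<8x8xf16>, vector<4xf16>
    scf.yield %l : vector<4xf16>
  } else {
    scf.yield %passthru : vector<4xf16>
  }

  // A vector.maskedstore with the same kind of mask becomes an scf.if
  // (with no else branch) around a plain vector.store.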

Added: 
    

Modified: 
    mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
    mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index 9a368f372c296..60c8660658a95 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
@@ -52,13 +53,25 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
 }
 
 static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
-                                           vector::MaskedLoadOp maskedOp) {
+                                           vector::MaskedLoadOp maskedOp,
+                                           bool passthru) {
   VectorType vectorType = maskedOp.getVectorType();
   Value load = builder.create<vector::LoadOp>(
       loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
-  Value res = builder.create<arith::SelectOp>(
-      loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru());
-  return res;
+  if (passthru)
+    load = builder.create<arith::SelectOp>(loc, vectorType, maskedOp.getMask(),
+                                           load, maskedOp.getPassThru());
+  return load;
+}
+
+/// Check if the given value comes from a broadcasted i1 condition.
+static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+  auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
+  if (!broadcastOp)
+    return failure();
+  if (isa<VectorType>(broadcastOp.getSourceType()))
+    return failure();
+  return broadcastOp.getSource();
 }
 
 static constexpr char kMaskedloadNeedsMask[] =
@@ -78,6 +91,16 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
       return failure();
     }
 
+    // Check if this is either a full inbounds load or an empty, oob load. If
+    // so, take the fast path and don't generate an if condition, because we
+    // know doing the oob load is always safe.
+    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
+      Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
+                                                 maskedOp, /*passthru=*/true);
+      rewriter.replaceOp(maskedOp, load);
+      return success();
+    }
+
     Location loc = maskedOp.getLoc();
     Value src = maskedOp.getBase();
 
@@ -135,7 +158,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     };
 
     auto elseBuilder = [&](OpBuilder &builder, Location loc) {
-      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp);
+      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
+                                                /*passthru=*/true);
       rewriter.create<scf::YieldOp>(loc, res);
     };
 
@@ -148,11 +172,63 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
   }
 };
 
+struct FullMaskedLoadToConditionalLoad
+    : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
+    if (failed(maybeCond)) {
+      return failure();
+    }
+
+    Value cond = maybeCond.value();
+    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
+      Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
+                                                /*passthru=*/false);
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+    auto falseBuilder = [&](OpBuilder &builder, Location loc) {
+      rewriter.create<scf::YieldOp>(loc, loadOp.getPassThru());
+    };
+    auto ifOp = rewriter.create<scf::IfOp>(loadOp.getLoc(), cond, trueBuilder,
+                                           falseBuilder);
+    rewriter.replaceOp(loadOp, ifOp);
+    return success();
+  }
+};
+
+struct FullMaskedStoreToConditionalStore
+    : OpRewritePattern<vector::MaskedStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
+    if (failed(maybeCond)) {
+      return failure();
+    }
+    Value cond = maybeCond.value();
+
+    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
+      rewriter.create<vector::StoreOp>(loc, storeOp.getValueToStore(),
+                                       storeOp.getBase(), storeOp.getIndices());
+      rewriter.create<scf::YieldOp>(loc);
+    };
+    auto ifOp = rewriter.create<scf::IfOp>(storeOp.getLoc(), cond, trueBuilder);
+    rewriter.replaceOp(storeOp, ifOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<MaskedLoadLowering>(patterns.getContext(), benefit);
+  patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
+               FullMaskedStoreToConditionalStore>(patterns.getContext(),
+                                                  benefit);
 }
 
 struct AmdgpuMaskedloadToLoadPass final

diff --git a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
index febe46bf7a759..f1d0ad545539a 100644
--- a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
@@ -114,3 +114,56 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
 // CHECK: %[[IF:.*]] = scf.if
 // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG1]]]
 // CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[ARG3]]
+
+// -----
+
+func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> {
+  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_fatrawbuffer_to_load
+func.func @full_select_maskedload_fatrawbuffer_to_load(%arg0: memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: vector.load
+// CHECK: arith.select
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: scf.if %[[PRED]]
+// CHECK:   %[[LOAD:.+]] = vector.load
+// CHECK:   scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK:   scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_mask_maskedstore_to_store
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+func.func @full_mask_maskedstore_to_store(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) {
+  %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+  vector.maskedstore %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16>
+  return
+}
+// CHECK-NOT: vector.maskedstore
+// CHECK: scf.if %[[PRED]]
+// CHECK:   vector.store

More information about the Mlir-commits mailing list