[Mlir-commits] [mlir] [AMDGPU] Do load unconditionally when converting from masked load (PR #162372)

Nirvedh Meshram llvmlistbot at llvm.org
Tue Oct 7 14:32:13 PDT 2025


https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/162372

>From 5cd9dcf7b19761dfb20283adb32a9fa4286855c0 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Tue, 7 Oct 2025 14:01:41 -0700
Subject: [PATCH] [AMDGPU] Do load unconditionally when converting from masked
 load

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp | 8 +++-----
 mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir        | 9 ++++-----
 2 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c166e0a..e1024eeefd4bd 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -88,13 +88,13 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
       return failure();
     }
+    Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
+                                               maskedOp, /*passthru=*/true);
 
     // Check if this is either a full inbounds load or an empty, oob load. If
     // so, take the fast path and don't generate an if condition, because we
     // know doing the oob load is always safe.
     if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
-      Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
-                                                 maskedOp, /*passthru=*/true);
       rewriter.replaceOp(maskedOp, load);
       return success();
     }
@@ -156,9 +156,7 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     };
 
     auto elseBuilder = [&](OpBuilder &builder, Location loc) {
-      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
-                                                /*passthru=*/true);
-      scf::YieldOp::create(rewriter, loc, res);
+      scf::YieldOp::create(rewriter, loc, load);
     };
 
     auto ifOp =
diff --git a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
index f1d0ad545539a..fae0d3870d7fd 100644
--- a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
@@ -9,13 +9,13 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
   %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
   return %res : vector<4xf32>
 }
-
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %[[ARG2]], %[[LOAD]]
 // CHECK: %[[IF:.*]] = scf.if
 // CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]]
 
 // CHECK: } else {
-// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %[[ARG2]], %[[LOAD]]
+// CHECK: scf.yield %[[SELECT]]
 
 // CHECK: return %[[IF]] : vector<4xf32>
 
@@ -36,18 +36,17 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
 // CHECK-DAG: %[[BYTES:.*]] = arith.constant 2
 // CHECK-DAG: %[[C4:.*]] = arith.constant 4
 
+// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
 // CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[ARG2]]]
 // CHECK: %[[DELTA:.*]] = arith.subi %[[SIZE]], %[[LINEAR]]
 // CHECK: %[[COND1:.*]] = arith.cmpi ult, %[[DELTA]], %[[C4]]
 
 // CHECK: %[[REM:.*]] = arith.remui %[[DELTA]], %[[BYTES]]
 // CHECK: %[[COND2:.*]] = arith.cmpi ne, %[[REM]], %[[C0]]
-
 // CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]]
 // CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (vector<4xf16>) {
 // CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]]
 // CHECK: } else {
-// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
 // CHECK: return %[[IF]] : vector<4xf16>
 
 // -----



More information about the Mlir-commits mailing list