[llvm] [mlir] [MLIR][AMDGPU] Adding dynamic size check to avoid subword buffer load (PR #135014)

Zhuoran Yin via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 14 07:20:02 PDT 2025


https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/135014

>From 27c5497a5dc96fe8a05e775e0c9f793927c5d59c Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 8 Apr 2025 17:30:05 +0000
Subject: [PATCH 1/2] Adding dynamic size check to avoid subword buffer load

---
 .../mlir/Dialect/AMDGPU/Transforms/Passes.td  | 13 ++-
 .../Dialect/AMDGPU/Transforms/CMakeLists.txt  |  1 +
 .../AMDGPU/Transforms/TransferReadToLoad.cpp  | 93 ++++++++++++++++---
 .../Dialect/AMDGPU/transfer-read-to-load.mlir | 48 +++++++++-
 .../llvm-project-overlay/mlir/BUILD.bazel     |  1 +
 5 files changed, 137 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 761caa448a57c..0e858108acf35 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -54,15 +54,20 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
 def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
   let summary = "Lower the operations from the vector transfer_read to vector load";
   let description = [{
-    This pass creates a transfer read op lowering. A vector trasfer read op
-    will be lowered to a combination of vector.load, arith.select and
-    vector.broadcast.
+    This pass creates a transfer read op lowering optimization. The lowering
+    emits a runtime check: if the read is in bounds (or its in-bounds part is
+    32-bit word aligned), a vector transfer read op is lowered to a combination
+    of vector.load, arith.select and vector.broadcast. Otherwise, it falls back
+    to the default lowering of the transfer_read op.
 
     This pattern will make it possible for masked transfer_read to be lowered
     towards buffer load with bounds check, allowing a more optimized global
     load accessing pattern compared with existing implementation of
     llvm.intr.masked.load on vectors.
   }];
-  let dependentDialects = [];
+  let dependentDialects = [
+    "scf::SCFDialect",
+    "memref::MemRefDialect"
+  ];
 }
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index bc5b6e9186449..8709a27e0168e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
   MLIRAMDGPUUtils
   MLIRArithDialect
   MLIRMemRefDialect
+  MLIRSCFDialect
   MLIRVectorDialect
   MLIRControlFlowDialect
   MLIRFuncDialect
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 3c1a2eb962037..519f695d99f91 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -9,6 +9,8 @@
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
@@ -108,6 +110,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
+    if (readOp->hasAttr("amdgpu.transformed"))
+      return failure();
 
     bool requiresBroadcasting = false;
     VectorType unbroadcastedVectorType;
@@ -117,20 +121,85 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     }
 
     Location loc = readOp.getLoc();
-    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getPadding());
-    Value load = rewriter.create<vector::LoadOp>(
-        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
-    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                                 readOp.getMask(), load, fill);
-
-    // Insert a broadcasting op if required.
-    if (requiresBroadcasting) {
-      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
-                                                 res);
+    Value src = readOp.getSource();
+    MemRefType memRefType = cast<MemRefType>(src.getType());
+    ArrayRef<int64_t> shape = memRefType.getShape();
+
+    Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stride = one;
+
+    // Compute the linear index by linearIndex += indices[i] * stride
+    for (int i = shape.size() - 1; i >= 0; --i) {
+      Value currentIndex = readOp.getIndices()[i];
+      Value strideIndexed =
+          rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
+      linearIndex =
+          rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
+
+      if (i == 0)
+        break;
+
+      // Update stride for the next dimension
+      Value nextStride;
+      if (shape[i] != ShapedType::kDynamic) {
+        nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+      } else {
+        nextStride = rewriter.create<memref::DimOp>(loc, src, i);
+      }
+      stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
+    }
+
+    // Add vector size offset to linear index
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value upperBoundIndex =
+        rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
+
+    Value totalSize = one;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      Value dimensionSize;
+      if (shape[i] != ShapedType::kDynamic) {
+        dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+      } else {
+        dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
+      }
+      totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
     }
 
-    rewriter.replaceOp(readOp, res);
+    Value isInBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+
+    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+      Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+                                                   readOp.getPadding());
+      Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getSource(),
+                                                  readOp.getIndices());
+      Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getMask(), load, fill);
+
+      // Insert a broadcasting op if required.
+      if (requiresBroadcasting) {
+        res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+                                                  res);
+      }
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Operation *read = builder.clone(*readOp.getOperation());
+      read->setAttr("amdgpu.transformed", builder.getUnitAttr());
+      Value readResult = read->getResult(0);
+      builder.create<scf::YieldOp>(loc, readResult);
+    };
+
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, isInBounds, thenBuilder, elseBuilder);
+
+    rewriter.replaceOp(readOp, ifOp);
 
     return success();
   }
diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
index 3e1283579f2b1..776a047e6a85d 100644
--- a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
@@ -10,10 +10,54 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
   return %res : vector<4xf32>
 }
 // CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[MUL0:.*]] = arith.muli %[[ARG1]], %[[C1]]
+// CHECK: %[[ADD0:.*]] = arith.addi %[[C0]], %[[MUL0]]
+// CHECK: %[[C8:.*]] = arith.constant 8
+// CHECK: %[[MUL1:.*]] = arith.muli %[[C1]], %[[C8]]
+// CHECK: %[[MUL2:.*]] = arith.muli %[[ARG1]], %[[MUL1]]
+// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[MUL2]]
+// CHECK: %[[C4:.*]] = arith.constant 4
+// CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[C4]]
+
+// CHECK: %[[MUL3:.*]] = arith.muli %[[C1]], %[[C8]]
+// CHECK: %[[MUL4:.*]] = arith.muli
+
+// CHECK: %[[CMP:.*]] = arith.cmpi ule, %[[ADD2]], %[[MUL4]]
+// CHECK: %[[IF:.*]] = scf.if %[[CMP]] -> (vector<4xf32>) {
+
 // CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
 // CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<4xf32>
+
+// CHECK: } else {
+// CHECK: %[[LOAD:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {amdgpu.transformed, in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+
+// CHECK: return %[[IF]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer_dynamic(%mem : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[MUL0:.*]] = arith.muli %{{.*}}, %[[DIM1]]
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[MUL1:.*]] = arith.muli %{{.*}}, %[[DIM0]]
+
+// CHECK: %[[C1_1:.*]] = arith.constant 1
+// CHECK: %[[DIM1_1:.*]] = memref.dim %[[ARG0]], %[[C1_1]]
+// CHECK: %[[MUL2:.*]] = arith.muli %{{.*}}, %[[DIM1_1]]
 
 // -----
 
@@ -64,7 +108,6 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
 // CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
 // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
-// CHECK: return %[[BROADCAST]] : vector<4xf32>
 
 // -----
 
@@ -83,4 +126,3 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
 // CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
 // CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<1xf32>
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 141986392917e..c4d87484fd5d5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1568,6 +1568,7 @@ cc_library(
         ":IR",
         ":MemRefDialect",
         ":Pass",
+        ":SCFDialect",
         ":SideEffectInterfaces",
         ":Support",
         ":TransformUtils",

>From eac8c2bd90117134bd21a0a07f0df33e75b1d081 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 10 Apr 2025 17:30:31 +0000
Subject: [PATCH 2/2] Relaxing condition to do bounds check

---
 .../AMDGPU/Transforms/TransferReadToLoad.cpp  | 78 ++++++++++++-------
 1 file changed, 50 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 519f695d99f91..9dbcebee1252a 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -103,6 +103,23 @@ static LogicalResult transferPreconditions(
   return success();
 }
 
+/// Fast-path lowering: arith.select(mask, vector.load, splat(padding)).
+static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
+                                           vector::TransferReadOp readOp,
+                                           bool requiresBroadcasting,
+                                           VectorType unbroadcastedVectorType) {
+  Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+                                               readOp.getPadding());
+  Value load = builder.create<vector::LoadOp>(
+      loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
+  Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+                                              readOp.getMask(), load, fill);
+  // Broadcast back to the transfer_read's result type when required.
+  if (requiresBroadcasting)
+    res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
+  return res;
+}
+
 namespace {
 
 struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
@@ -150,14 +167,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
       stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
     }
 
-    // Add vector size offset to linear index
-    VectorType vectorType = readOp.getVectorType();
-    int64_t vectorSize = vectorType.getNumElements();
-    Value vectorSizeOffset =
-        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
-    Value upperBoundIndex =
-        rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
-
     Value totalSize = one;
     for (size_t i = 0; i < shape.size(); ++i) {
       Value dimensionSize;
@@ -169,35 +178,48 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
       totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
     }
 
-    Value isInBounds = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+    // delta = bufferSize - linearizedOffset
+    // 1) check if delta < vectorSize
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
+
+    // 2) check if (delta * elementBitWidth) % 32 != 0, i.e. the in-bounds
+    //    part of the read does not end on a 32-bit word boundary.
+    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+    Value deltaBits = rewriter.create<arith::MulIOp>(
+        loc, delta,
+        rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth));
+    Value wordBitWidth = rewriter.create<arith::ConstantIndexOp>(loc, 32);
+    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne,
+        rewriter.create<arith::RemUIOp>(loc, deltaBits, wordBitWidth),
+        rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+    // Take the fallback (default transfer_read lowering) only when the read
+    // is both out-of-bounds and not word aligned.
+    Value ifCondition =
+        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
 
     auto thenBuilder = [&](OpBuilder &builder, Location loc) {
-      Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                                   readOp.getPadding());
-      Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getSource(),
-                                                  readOp.getIndices());
-      Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getMask(), load, fill);
-
-      // Insert a broadcasting op if required.
-      if (requiresBroadcasting) {
-        res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
-                                                  res);
-      }
-      rewriter.create<scf::YieldOp>(loc, res);
-    };
-
-    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
       Operation *read = builder.clone(*readOp.getOperation());
       read->setAttr("amdgpu.transformed", builder.getUnitAttr());
       Value readResult = read->getResult(0);
       builder.create<scf::YieldOp>(loc, readResult);
     };
 
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Value res = createVectorLoadForMaskedLoad(
+          builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
+      builder.create<scf::YieldOp>(loc, res);
+    };
+
     auto ifOp =
-        rewriter.create<scf::IfOp>(loc, isInBounds, thenBuilder, elseBuilder);
+        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
 
     rewriter.replaceOp(readOp, ifOp);
 



More information about the llvm-commits mailing list