[llvm] [mlir] [MLIR][AMDGPU] Adding dynamic size check to avoid subword buffer load (PR #135014)
Zhuoran Yin via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 14 07:20:02 PDT 2025
https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/135014
From 27c5497a5dc96fe8a05e775e0c9f793927c5d59c Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 8 Apr 2025 17:30:05 +0000
Subject: [PATCH 1/2] Adding dynamic size check to avoid subword buffer load
---
.../mlir/Dialect/AMDGPU/Transforms/Passes.td | 13 ++-
.../Dialect/AMDGPU/Transforms/CMakeLists.txt | 1 +
.../AMDGPU/Transforms/TransferReadToLoad.cpp | 93 ++++++++++++++++---
.../Dialect/AMDGPU/transfer-read-to-load.mlir | 48 +++++++++-
.../llvm-project-overlay/mlir/BUILD.bazel | 1 +
5 files changed, 137 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 761caa448a57c..0e858108acf35 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -54,15 +54,20 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
let summary = "Lower the operations from the vector transfer_read to vector load";
let description = [{
- This pass creates a transfer read op lowering. A vector trasfer read op
- will be lowered to a combination of vector.load, arith.select and
- vector.broadcast.
+    This pass creates an optimized lowering for transfer read ops. The lowering
+    emits a runtime bounds check. When the access is in bounds, a vector
+    transfer read op is lowered to a combination of vector.load, arith.select
+    and vector.broadcast. Otherwise, it falls back to the default lowering
+    of the transfer_read op.
This pattern makes it possible for masked transfer_read to be lowered
towards buffer load with bounds check, allowing a more optimized global
load access pattern compared with the existing implementation of
llvm.intr.masked.load on vectors.
}];
- let dependentDialects = [];
+ let dependentDialects = [
+ "scf::SCFDialect",
+ "memref::MemRefDialect"
+ ];
}
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
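As a hand-written sketch of the intended rewrite (illustrative only; the
authoritative output is pinned down by the test updates below), a masked read
such as

    %res = vector.transfer_read %mem[%i, %j], %cf0, %mask {in_bounds = [true]}
        : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>

would become roughly the following, with the bounds check shown constant-folded
here for the static 8x8 shape and the vector<4xf32> result:

    %c4 = arith.constant 4 : index
    %c8 = arith.constant 8 : index
    %c64 = arith.constant 64 : index
    %row = arith.muli %i, %c8 : index
    %lin = arith.addi %j, %row : index
    %ub = arith.addi %lin, %c4 : index
    %in_bounds = arith.cmpi ule, %ub, %c64 : index
    %res = scf.if %in_bounds -> (vector<4xf32>) {
      %fill = vector.splat %cf0 : vector<4xf32>
      %load = vector.load %mem[%i, %j]
          : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
      %sel = arith.select %mask, %load, %fill : vector<4xi1>, vector<4xf32>
      scf.yield %sel : vector<4xf32>
    } else {
      %slow = vector.transfer_read %mem[%i, %j], %cf0, %mask
          {amdgpu.transformed, in_bounds = [true]}
          : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
      scf.yield %slow : vector<4xf32>
    }

The amdgpu.transformed unit attribute on the cloned transfer_read keeps the
pattern from matching its own fallback again.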
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index bc5b6e9186449..8709a27e0168e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
MLIRAMDGPUUtils
MLIRArithDialect
MLIRMemRefDialect
+ MLIRSCFDialect
MLIRVectorDialect
MLIRControlFlowDialect
MLIRFuncDialect
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 3c1a2eb962037..519f695d99f91 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -9,6 +9,8 @@
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
@@ -108,6 +110,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
+ if (readOp->hasAttr("amdgpu.transformed"))
+ return failure();
bool requiresBroadcasting = false;
VectorType unbroadcastedVectorType;
@@ -117,20 +121,85 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
}
Location loc = readOp.getLoc();
- Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
- readOp.getPadding());
- Value load = rewriter.create<vector::LoadOp>(
- loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
- Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
- readOp.getMask(), load, fill);
-
- // Insert a broadcasting op if required.
- if (requiresBroadcasting) {
- res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
- res);
+ Value src = readOp.getSource();
+ MemRefType memRefType = cast<MemRefType>(src.getType());
+ ArrayRef<int64_t> shape = memRefType.getShape();
+
+ Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value stride = one;
+
+ // Compute the linear index by linearIndex += indices[i] * stride
+ for (int i = shape.size() - 1; i >= 0; --i) {
+ Value currentIndex = readOp.getIndices()[i];
+ Value strideIndexed =
+ rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
+ linearIndex =
+ rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
+
+ if (i == 0)
+ break;
+
+ // Update stride for the next dimension
+ Value nextStride;
+ if (shape[i] != ShapedType::kDynamic) {
+ nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+ } else {
+ nextStride = rewriter.create<memref::DimOp>(loc, src, i);
+ }
+ stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
+ }
+
+ // Add vector size offset to linear index
+ VectorType vectorType = readOp.getVectorType();
+ int64_t vectorSize = vectorType.getNumElements();
+ Value vectorSizeOffset =
+ rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+ Value upperBoundIndex =
+ rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
+
+ Value totalSize = one;
+ for (size_t i = 0; i < shape.size(); ++i) {
+ Value dimensionSize;
+ if (shape[i] != ShapedType::kDynamic) {
+ dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+ } else {
+ dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
+ }
+ totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
}
- rewriter.replaceOp(readOp, res);
+ Value isInBounds = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+
+ auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+ Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+ readOp.getPadding());
+ Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
+ readOp.getSource(),
+ readOp.getIndices());
+ Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+ readOp.getMask(), load, fill);
+
+ // Insert a broadcasting op if required.
+ if (requiresBroadcasting) {
+ res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+ res);
+ }
+ rewriter.create<scf::YieldOp>(loc, res);
+ };
+
+ auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+ Operation *read = builder.clone(*readOp.getOperation());
+ read->setAttr("amdgpu.transformed", builder.getUnitAttr());
+ Value readResult = read->getResult(0);
+ builder.create<scf::YieldOp>(loc, readResult);
+ };
+
+ auto ifOp =
+ rewriter.create<scf::IfOp>(loc, isInBounds, thenBuilder, elseBuilder);
+
+ rewriter.replaceOp(readOp, ifOp);
return success();
}
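For dynamic shapes the same linearization is built out of memref.dim ops. A
hand-written sketch (SSA names invented) of what the two loops above emit for a
read of memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>> at [%i, %j]:

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // linearIndex = %j * 1 + %i * dim(%mem, 1)
    %m0 = arith.muli %j, %c1 : index
    %lin0 = arith.addi %c0, %m0 : index
    %d1 = memref.dim %mem, %c1
        : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
    %stride1 = arith.muli %c1, %d1 : index
    %m1 = arith.muli %i, %stride1 : index
    %lin = arith.addi %lin0, %m1 : index
    // totalSize = dim(%mem, 0) * dim(%mem, 1)
    %d0 = memref.dim %mem, %c0
        : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
    %t0 = arith.muli %c1, %d0 : index
    %total = arith.muli %t0, %d1 : index

Note the row-major convention: the stride of dimension 0 is the size of
dimension 1, which is why the first loop walks the indices from the innermost
dimension outwards. (The pattern re-emits memref.dim in each loop rather than
reusing %d1, as the CHECK lines in the dynamic test below reflect.)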
diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
index 3e1283579f2b1..776a047e6a85d 100644
--- a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
@@ -10,10 +10,54 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[MUL0:.*]] = arith.muli %[[ARG1]], %[[C1]]
+// CHECK: %[[ADD0:.*]] = arith.addi %[[C0]], %[[MUL0]]
+// CHECK: %[[C8:.*]] = arith.constant 8
+// CHECK: %[[MUL1:.*]] = arith.muli %[[C1]], %[[C8]]
+// CHECK: %[[MUL2:.*]] = arith.muli %[[ARG1]], %[[MUL1]]
+// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[MUL2]]
+// CHECK: %[[C4:.*]] = arith.constant 4
+// CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[C4]]
+
+// CHECK: %[[MUL3:.*]] = arith.muli %[[C1]], %[[C8]]
+// CHECK: %[[MUL4:.*]] = arith.muli
+
+// CHECK: %[[CMP:.*]] = arith.cmpi ule, %[[ADD2]], %[[MUL4]]
+// CHECK: %[[IF:.*]] = scf.if %[[CMP]] -> (vector<4xf32>) {
+
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<4xf32>
+
+// CHECK: } else {
+// CHECK: %[[LOAD:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {amdgpu.transformed, in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+
+// CHECK: return %[[IF]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer_dynamic(%mem : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[MUL0:.*]] = arith.muli %{{.*}}, %[[DIM1]]
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[MUL1:.*]] = arith.muli %{{.*}}, %[[DIM0]]
+
+// CHECK: %[[C1_1:.*]] = arith.constant 1
+// CHECK: %[[DIM1_1:.*]] = memref.dim %[[ARG0]], %[[C1_1]]
+// CHECK: %[[MUL2:.*]] = arith.muli %{{.*}}, %[[DIM1_1]]
// -----
@@ -64,7 +108,6 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
-// CHECK: return %[[BROADCAST]] : vector<4xf32>
// -----
@@ -83,4 +126,3 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<1xf32>
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 141986392917e..c4d87484fd5d5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1568,6 +1568,7 @@ cc_library(
":IR",
":MemRefDialect",
":Pass",
+ ":SCFDialect",
":SideEffectInterfaces",
":Support",
":TransformUtils",
From eac8c2bd90117134bd21a0a07f0df33e75b1d081 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 10 Apr 2025 17:30:31 +0000
Subject: [PATCH 2/2] Relaxing condition to do bounds check
---
.../AMDGPU/Transforms/TransferReadToLoad.cpp | 78 ++++++++++++-------
1 file changed, 50 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 519f695d99f91..9dbcebee1252a 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -103,6 +103,23 @@ static LogicalResult transferPreconditions(
return success();
}
+static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
+ vector::TransferReadOp readOp,
+ bool requiresBroadcasting,
+ VectorType unbroadcastedVectorType) {
+ Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+ readOp.getPadding());
+ Value load = builder.create<vector::LoadOp>(
+ loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
+ Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+ readOp.getMask(), load, fill);
+ // Insert a broadcasting op if required.
+ if (requiresBroadcasting) {
+ res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
+ }
+ return res;
+}
+
namespace {
struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
@@ -150,14 +167,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
}
- // Add vector size offset to linear index
- VectorType vectorType = readOp.getVectorType();
- int64_t vectorSize = vectorType.getNumElements();
- Value vectorSizeOffset =
- rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
- Value upperBoundIndex =
- rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
-
Value totalSize = one;
for (size_t i = 0; i < shape.size(); ++i) {
Value dimensionSize;
@@ -169,35 +178,48 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
}
- Value isInBounds = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+ // delta = totalSize (buffer size) - linearIndex (linearized offset)
+ // 1) check if delta <= vectorSize
+ VectorType vectorType = readOp.getVectorType();
+ int64_t vectorSize = vectorType.getNumElements();
+ Value vectorSizeOffset =
+ rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+ Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+ Value isOutOfBounds = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
+
+ // 2) check if (deltaBytes % (32 / elementBitWidth) != 0)
+ int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+ Value deltaBytes = rewriter.create<arith::MulIOp>(
+ loc, delta,
+ rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
+ Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
+ loc, elementBitWidth < 32 ? 32 / elementBitWidth : 1);
+ Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne,
+ rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
+ rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+ // We take the fallback of the default transfer_read lowering only if it is
+ // both out-of-bounds and not word aligned.
+ Value ifCondition =
+ rewriter.create<arith::AndIOp>(loc, isOutOfBounds, isNotWordAligned);
auto thenBuilder = [&](OpBuilder &builder, Location loc) {
- Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
- readOp.getPadding());
- Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
- readOp.getSource(),
- readOp.getIndices());
- Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
- readOp.getMask(), load, fill);
-
- // Insert a broadcasting op if required.
- if (requiresBroadcasting) {
- res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
- res);
- }
- rewriter.create<scf::YieldOp>(loc, res);
- };
-
- auto elseBuilder = [&](OpBuilder &builder, Location loc) {
Operation *read = builder.clone(*readOp.getOperation());
read->setAttr("amdgpu.transformed", builder.getUnitAttr());
Value readResult = read->getResult(0);
builder.create<scf::YieldOp>(loc, readResult);
};
+ auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+ Value res = createVectorLoadForMaskedLoad(
+ builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
+ rewriter.create<scf::YieldOp>(loc, res);
+ };
+
auto ifOp =
- rewriter.create<scf::IfOp>(loc, isInBounds, thenBuilder, elseBuilder);
+ rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
rewriter.replaceOp(readOp, ifOp);
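To make the relaxed predicate concrete, here is a hand-written sketch (invented
SSA names, not taken from the patch) of what the code above emits for a
vector<4xi8> read, where elementBitWidth = 8, one byte per element, and
elementsPerWord = 32 / 8 = 4:

    %delta = arith.subi %totalSize, %linearIndex : index
    // 1) out of bounds if the remaining elements cannot cover the vector
    %c4 = arith.constant 4 : index
    %oob = arith.cmpi ule, %delta, %c4 : index
    // 2) not word aligned if the remaining bytes are not a whole number of
    // 32-bit words
    %c1 = arith.constant 1 : index
    %deltaBytes = arith.muli %delta, %c1 : index
    %c4w = arith.constant 4 : index
    %rem = arith.remui %deltaBytes, %c4w : index
    %c0 = arith.constant 0 : index
    %unaligned = arith.cmpi ne, %rem, %c0 : index
    %cond = arith.andi %oob, %unaligned : i1

With delta = 3 remaining i8 elements both conditions hold, so the scf.if takes
the cloned transfer_read fallback; with delta = 4 the remaining bytes form a
whole word, the remainder is 0, and the vector.load fast path is used.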