[llvm] [mlir] [MLIR][AMDGPU] Adding dynamic size check to avoid subword buffer load (PR #135014)
Zhuoran Yin via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 14 14:09:32 PDT 2025
https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/135014
>From 27c5497a5dc96fe8a05e775e0c9f793927c5d59c Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 8 Apr 2025 17:30:05 +0000
Subject: [PATCH 1/6] Adding dynamic size check to avoid subword buffer load
---
.../mlir/Dialect/AMDGPU/Transforms/Passes.td | 13 ++-
.../Dialect/AMDGPU/Transforms/CMakeLists.txt | 1 +
.../AMDGPU/Transforms/TransferReadToLoad.cpp | 93 ++++++++++++++++---
.../Dialect/AMDGPU/transfer-read-to-load.mlir | 48 +++++++++-
.../llvm-project-overlay/mlir/BUILD.bazel | 1 +
5 files changed, 137 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 761caa448a57c..0e858108acf35 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -54,15 +54,20 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
let summary = "Lower the operations from the vector transfer_read to vector load";
let description = [{
- This pass creates a transfer read op lowering. A vector trasfer read op
- will be lowered to a combination of vector.load, arith.select and
- vector.broadcast.
+ This pass creates a transfer read op lowering optimization. The lowering
+ will produce a conditional check at runtime. If within bounds, a vector
+ transfer read op will be lowered to a combination of vector.load, arith.select
+ and vector.broadcast. If not, it will fall back to the default lowering
+ of the transfer_read op.
This pattern will make it possible for masked transfer_read to be lowered
towards buffer load with bounds check, allowing a more optimized global
load accessing pattern compared with existing implementation of
llvm.intr.masked.load on vectors.
}];
- let dependentDialects = [];
+ let dependentDialects = [
+ "scf::SCFDialect",
+ "memref::MemRefDialect"
+ ];
}
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index bc5b6e9186449..8709a27e0168e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
MLIRAMDGPUUtils
MLIRArithDialect
MLIRMemRefDialect
+ MLIRSCFDialect
MLIRVectorDialect
MLIRControlFlowDialect
MLIRFuncDialect
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 3c1a2eb962037..519f695d99f91 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -9,6 +9,8 @@
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
@@ -108,6 +110,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
+ if (readOp->hasAttr("amdgpu.transformed"))
+ return failure();
bool requiresBroadcasting = false;
VectorType unbroadcastedVectorType;
@@ -117,20 +121,85 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
}
Location loc = readOp.getLoc();
- Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
- readOp.getPadding());
- Value load = rewriter.create<vector::LoadOp>(
- loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
- Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
- readOp.getMask(), load, fill);
-
- // Insert a broadcasting op if required.
- if (requiresBroadcasting) {
- res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
- res);
+ Value src = readOp.getSource();
+ MemRefType memRefType = cast<MemRefType>(src.getType());
+ ArrayRef<int64_t> shape = memRefType.getShape();
+
+ Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value stride = one;
+
+ // Compute the linear index by linearIndex += indices[i] * stride
+ for (int i = shape.size() - 1; i >= 0; --i) {
+ Value currentIndex = readOp.getIndices()[i];
+ Value strideIndexed =
+ rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
+ linearIndex =
+ rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
+
+ if (i == 0)
+ break;
+
+ // Update stride for the next dimension
+ Value nextStride;
+ if (shape[i] != ShapedType::kDynamic) {
+ nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+ } else {
+ nextStride = rewriter.create<memref::DimOp>(loc, src, i);
+ }
+ stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
+ }
+
+ // Add vector size offset to linear index
+ VectorType vectorType = readOp.getVectorType();
+ int64_t vectorSize = vectorType.getNumElements();
+ Value vectorSizeOffset =
+ rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+ Value upperBoundIndex =
+ rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
+
+ Value totalSize = one;
+ for (size_t i = 0; i < shape.size(); ++i) {
+ Value dimensionSize;
+ if (shape[i] != ShapedType::kDynamic) {
+ dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+ } else {
+ dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
+ }
+ totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
}
- rewriter.replaceOp(readOp, res);
+ Value isInBounds = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+
+ auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+ Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+ readOp.getPadding());
+ Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
+ readOp.getSource(),
+ readOp.getIndices());
+ Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+ readOp.getMask(), load, fill);
+
+ // Insert a broadcasting op if required.
+ if (requiresBroadcasting) {
+ res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+ res);
+ }
+ rewriter.create<scf::YieldOp>(loc, res);
+ };
+
+ auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+ Operation *read = builder.clone(*readOp.getOperation());
+ read->setAttr("amdgpu.transformed", builder.getUnitAttr());
+ Value readResult = read->getResult(0);
+ builder.create<scf::YieldOp>(loc, readResult);
+ };
+
+ auto ifOp =
+ rewriter.create<scf::IfOp>(loc, isInBounds, thenBuilder, elseBuilder);
+
+ rewriter.replaceOp(readOp, ifOp);
return success();
}
diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
index 3e1283579f2b1..776a047e6a85d 100644
--- a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
@@ -10,10 +10,54 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[MUL0:.*]] = arith.muli %[[ARG1]], %[[C1]]
+// CHECK: %[[ADD0:.*]] = arith.addi %[[C0]], %[[MUL0]]
+// CHECK: %[[C8:.*]] = arith.constant 8
+// CHECK: %[[MUL1:.*]] = arith.muli %[[C1]], %[[C8]]
+// CHECK: %[[MUL2:.*]] = arith.muli %[[ARG1]], %[[MUL1]]
+// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[MUL2]]
+// CHECK: %[[C4:.*]] = arith.constant 4
+// CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[C4]]
+
+// CHECK: %[[MUL3:.*]] = arith.muli %[[C1]], %[[C8]]
+// CHECK: %[[MUL4:.*]] = arith.muli
+
+// CHECK: %[[CMP:.*]] = arith.cmpi ule, %[[ADD2]], %[[MUL4]]
+// CHECK: %[[IF:.*]] = scf.if %[[CMP]] -> (vector<4xf32>) {
+
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<4xf32>
+
+// CHECK: } else {
+// CHECK: %[[LOAD:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {amdgpu.transformed, in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+
+// CHECK: return %[[IF]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer_dynamic(%mem : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[MUL0:.*]] = arith.muli %{{.*}}, %[[DIM1]]
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[MUL1:.*]] = arith.muli %{{.*}}, %[[DIM0]]
+
+// CHECK: %[[C1_1:.*]] = arith.constant 1
+// CHECK: %[[DIM1_1:.*]] = memref.dim %[[ARG0]], %[[C1_1]]
+// CHECK: %[[MUL2:.*]] = arith.muli %{{.*}}, %[[DIM1_1]]
// -----
@@ -64,7 +108,6 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
-// CHECK: return %[[BROADCAST]] : vector<4xf32>
// -----
@@ -83,4 +126,3 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<1xf32>
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 141986392917e..c4d87484fd5d5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1568,6 +1568,7 @@ cc_library(
":IR",
":MemRefDialect",
":Pass",
+ ":SCFDialect",
":SideEffectInterfaces",
":Support",
":TransformUtils",
>From eac8c2bd90117134bd21a0a07f0df33e75b1d081 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 10 Apr 2025 17:30:31 +0000
Subject: [PATCH 2/6] Relaxing condition to do bounds check
---
.../AMDGPU/Transforms/TransferReadToLoad.cpp | 78 ++++++++++++-------
1 file changed, 50 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 519f695d99f91..9dbcebee1252a 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -103,6 +103,23 @@ static LogicalResult transferPreconditions(
return success();
}
+static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
+ vector::TransferReadOp readOp,
+ bool requiresBroadcasting,
+ VectorType unbroadcastedVectorType) {
+ Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+ readOp.getPadding());
+ Value load = builder.create<vector::LoadOp>(
+ loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
+ Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+ readOp.getMask(), load, fill);
+ // Insert a broadcasting op if required.
+ if (requiresBroadcasting) {
+ res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
+ }
+ return res;
+}
+
namespace {
struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
@@ -150,14 +167,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
}
- // Add vector size offset to linear index
- VectorType vectorType = readOp.getVectorType();
- int64_t vectorSize = vectorType.getNumElements();
- Value vectorSizeOffset =
- rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
- Value upperBoundIndex =
- rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
-
Value totalSize = one;
for (size_t i = 0; i < shape.size(); ++i) {
Value dimensionSize;
@@ -169,35 +178,48 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
}
- Value isInBounds = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+ // delta = bufferSize - linearizedOffset
+ // 1) check if delta < vectorSize
+ VectorType vectorType = readOp.getVectorType();
+ int64_t vectorSize = vectorType.getNumElements();
+ Value vectorSizeOffset =
+ rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+ Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+ Value isOutofBounds = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
+
+ // 2) check if (detla(bytes) % (32 / elementBitwidth) != 0)
+ int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+ Value deltaBytes = rewriter.create<arith::MulIOp>(
+ loc, delta,
+ rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
+ Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
+ loc, elementBitWidth < 32 ? 32 / elementBitWidth : 1);
+ Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne,
+ rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
+ rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+ // We fall back to the default transfer_read lowering only if the access is
+ // both out-of-bounds and not word aligned.
+ Value ifCondition =
+ rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
auto thenBuilder = [&](OpBuilder &builder, Location loc) {
- Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
- readOp.getPadding());
- Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
- readOp.getSource(),
- readOp.getIndices());
- Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
- readOp.getMask(), load, fill);
-
- // Insert a broadcasting op if required.
- if (requiresBroadcasting) {
- res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
- res);
- }
- rewriter.create<scf::YieldOp>(loc, res);
- };
-
- auto elseBuilder = [&](OpBuilder &builder, Location loc) {
Operation *read = builder.clone(*readOp.getOperation());
read->setAttr("amdgpu.transformed", builder.getUnitAttr());
Value readResult = read->getResult(0);
builder.create<scf::YieldOp>(loc, readResult);
};
+ auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+ Value res = createVectorLoadForMaskedLoad(
+ builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
+ rewriter.create<scf::YieldOp>(loc, res);
+ };
+
auto ifOp =
- rewriter.create<scf::IfOp>(loc, isInBounds, thenBuilder, elseBuilder);
+ rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
rewriter.replaceOp(readOp, ifOp);
>From ca9d7df27328cebb792b3149716a89402eae97f4 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 10 Apr 2025 20:48:50 +0000
Subject: [PATCH 3/6] Use affine for index and size computations
---
.../AMDGPU/Transforms/TransferReadToLoad.cpp | 92 +++++++++++--------
1 file changed, 52 insertions(+), 40 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 9dbcebee1252a..a60718041cbcb 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -9,10 +9,15 @@
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
@@ -139,57 +144,64 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
Location loc = readOp.getLoc();
Value src = readOp.getSource();
- MemRefType memRefType = cast<MemRefType>(src.getType());
- ArrayRef<int64_t> shape = memRefType.getShape();
-
- Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- Value stride = one;
-
- // Compute the linear index by linearIndex += indices[i] * stride
- for (int i = shape.size() - 1; i >= 0; --i) {
- Value currentIndex = readOp.getIndices()[i];
- Value strideIndexed =
- rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
- linearIndex =
- rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
-
- if (i == 0)
- break;
-
- // Update stride for the next dimension
- Value nextStride;
- if (shape[i] != ShapedType::kDynamic) {
- nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
- } else {
- nextStride = rewriter.create<memref::DimOp>(loc, src, i);
- }
- stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
- }
- Value totalSize = one;
- for (size_t i = 0; i < shape.size(); ++i) {
- Value dimensionSize;
- if (shape[i] != ShapedType::kDynamic) {
- dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
- } else {
- dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
- }
- totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
+ VectorType vectorType = readOp.getVectorType();
+ int64_t vectorSize = vectorType.getNumElements();
+ int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+ // Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<OpFoldResult> indices = readOp.getIndices();
+
+ auto stridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+ memref::LinearizedMemRefInfo linearizedInfo;
+ OpFoldResult linearizedIndices;
+ std::tie(linearizedInfo, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, elementBitWidth, elementBitWidth,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(), indices);
+ // OpFoldResult linearIndexSize = linearizedInfo.linearizedSize;
+ Value linearIndex =
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+
+ // Note below doesn't give the correct result for the linearized size.
+ // It compute the mutiplied sizes of all dimensions instead of taking
+ // the maximum of each dimension size * stride.
+ // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
+ // Value totalSize = getValueOrCreateConstantIndexOp(
+ // rewriter, loc, linearizedInfo.linearizedSize);
+ SmallVector<AffineExpr> productExpressions;
+ SmallVector<Value> productResults;
+ unsigned sourceRank =
+ cast<ShapedType>(readOp.getSource().getType()).getRank();
+
+ SmallVector<AffineExpr> symbols(2 * sourceRank);
+ SmallVector<Value> offsetValues(2 * sourceRank);
+ bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+ for (size_t i = 0; i < sourceRank; ++i) {
+ unsigned offsetIdx = 2 * i;
+ productExpressions.push_back(symbols[offsetIdx] * symbols[offsetIdx + 1]);
+ offsetValues[offsetIdx] = stridedMetadata.getStrides()[i];
+ offsetValues[offsetIdx + 1] = stridedMetadata.getSizes()[i];
}
+ AffineMap maxMap = AffineMap::get(
+ /*dimCount=*/0, /*symbolCount=*/symbols.size(), productExpressions,
+ rewriter.getContext());
+ Value totalSize =
+ rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
+
// delta = bufferSize - linearizedOffset
- // 1) check if delta < vectorSize
- VectorType vectorType = readOp.getVectorType();
- int64_t vectorSize = vectorType.getNumElements();
Value vectorSizeOffset =
rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+
+ // 1) check if delta < vectorSize
Value isOutofBounds = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
// 2) check if (detla(bytes) % (32 / elementBitwidth) != 0)
- int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
Value deltaBytes = rewriter.create<arith::MulIOp>(
loc, delta,
rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
>From a09cd5163e5e9d17b3ec2f3e627037320bde9e46 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 11 Apr 2025 18:37:11 +0000
Subject: [PATCH 4/6] Invoking vector transfer lowering pattern in amdgpu pass
---
.../AMDGPU/Transforms/TransferReadToLoad.cpp | 22 +++++++++++--------
1 file changed, 13 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index a60718041cbcb..60839cbfaeae1 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -16,12 +16,14 @@
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace mlir::amdgpu {
@@ -132,7 +134,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
- if (readOp->hasAttr("amdgpu.transformed"))
+ if (readOp->hasAttr("amdgpu.buffer_transfer_read_needs_mask"))
return failure();
bool requiresBroadcasting = false;
@@ -148,7 +150,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
VectorType vectorType = readOp.getVectorType();
int64_t vectorSize = vectorType.getNumElements();
int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
- // Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<OpFoldResult> indices = readOp.getIndices();
auto stridedMetadata =
@@ -161,16 +162,15 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(), indices);
- // OpFoldResult linearIndexSize = linearizedInfo.linearizedSize;
Value linearIndex =
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
- // Note below doesn't give the correct result for the linearized size.
- // It compute the mutiplied sizes of all dimensions instead of taking
- // the maximum of each dimension size * stride.
// TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
+ // Note below doesn't give the correct result for the linearized size.
// Value totalSize = getValueOrCreateConstantIndexOp(
// rewriter, loc, linearizedInfo.linearizedSize);
+ // It compute the mutiplied sizes of all dimensions instead of taking
+ // the maximum of each dimension size * stride.
SmallVector<AffineExpr> productExpressions;
SmallVector<Value> productResults;
unsigned sourceRank =
@@ -201,7 +201,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
Value isOutofBounds = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
- // 2) check if (detla(bytes) % (32 / elementBitwidth) != 0)
+ // 2) check if (detla_bytes % (32 / elementBitwidth) != 0)
Value deltaBytes = rewriter.create<arith::MulIOp>(
loc, delta,
rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
@@ -219,7 +219,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
auto thenBuilder = [&](OpBuilder &builder, Location loc) {
Operation *read = builder.clone(*readOp.getOperation());
- read->setAttr("amdgpu.transformed", builder.getUnitAttr());
+ read->setAttr("amdgpu.buffer_transfer_read_needs_mask",
+ builder.getUnitAttr());
Value readResult = read->getResult(0);
builder.create<scf::YieldOp>(loc, readResult);
};
@@ -244,6 +245,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
RewritePatternSet &patterns) {
patterns.add<TransferReadLowering>(patterns.getContext());
+ vector::populateVectorTransferLoweringPatterns(patterns);
}
struct AmdgpuTransferReadToLoadPass final
@@ -252,6 +254,8 @@ struct AmdgpuTransferReadToLoadPass final
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateAmdgpuTransferReadToLoadPatterns(patterns);
- walkAndApplyPatterns(getOperation(), std::move(patterns));
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
};
>From d5967040d1889ff6f13d351564bd1f78ec014486 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 11 Apr 2025 21:34:42 +0000
Subject: [PATCH 5/6] Matching unit test with latest implementation
---
.../AMDGPU/Transforms/TransferReadToLoad.cpp | 6 +-
.../Dialect/AMDGPU/transfer-read-to-load.mlir | 119 ++++++++++--------
2 files changed, 71 insertions(+), 54 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 60839cbfaeae1..4cab311f2b6c7 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -154,9 +154,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
- memref::LinearizedMemRefInfo linearizedInfo;
OpFoldResult linearizedIndices;
- std::tie(linearizedInfo, linearizedIndices) =
+ std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, elementBitWidth, elementBitWidth,
stridedMetadata.getConstifiedMixedOffset(),
@@ -173,8 +172,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
// the maximum of each dimension size * stride.
SmallVector<AffineExpr> productExpressions;
SmallVector<Value> productResults;
- unsigned sourceRank =
- cast<ShapedType>(readOp.getSource().getType()).getRank();
+ unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
SmallVector<AffineExpr> symbols(2 * sourceRank);
SmallVector<Value> offsetValues(2 * sourceRank);
diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
index 776a047e6a85d..91b6d8b3137c8 100644
--- a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
@@ -9,55 +9,72 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
%res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
return %res : vector<4xf32>
}
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[C0:.*]] = arith.constant 0
-// CHECK: %[[C1:.*]] = arith.constant 1
-// CHECK: %[[MUL0:.*]] = arith.muli %[[ARG1]], %[[C1]]
-// CHECK: %[[ADD0:.*]] = arith.addi %[[C0]], %[[MUL0]]
-// CHECK: %[[C8:.*]] = arith.constant 8
-// CHECK: %[[MUL1:.*]] = arith.muli %[[C1]], %[[C8]]
-// CHECK: %[[MUL2:.*]] = arith.muli %[[ARG1]], %[[MUL1]]
-// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[MUL2]]
-// CHECK: %[[C4:.*]] = arith.constant 4
-// CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[C4]]
-
-// CHECK: %[[MUL3:.*]] = arith.muli %[[C1]], %[[C8]]
-// CHECK: %[[MUL4:.*]] = arith.muli
-
-// CHECK: %[[CMP:.*]] = arith.cmpi ule, %[[ADD2]], %[[MUL4]]
-// CHECK: %[[IF:.*]] = scf.if %[[CMP]] -> (vector<4xf32>) {
-
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
-// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
+// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]]
// CHECK: } else {
-// CHECK: %[[LOAD:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {amdgpu.transformed, in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %[[ARG2]], %[[LOAD]]
// CHECK: return %[[IF]] : vector<4xf32>
// -----
-// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic(
-// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
-// CHECK-SAME: %[[ARG1:.*]]: index
-// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
-func.func @transfer_to_maskedload_fatrawbuffer_dynamic(%mem : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
- %cf0 = arith.constant 0.0 : f32
- %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
- return %res : vector<4xf32>
+// CHECK: #map = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_f16(
+// CHECK-SAME: %[[ARG0:.+]]: memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>,
+// CHECK-SAME: %[[ARG1:.+]]: index, %[[ARG2:.+]]: index,
+// CHECK-SAME: %[[ARG3:.+]]: vector<4xi1>)
+func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>) -> vector<4xf16> {
+ %cf0 = arith.constant 0.0 : f16
+ %res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]} : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
+ return %res : vector<4xf16>
}
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
+// CHECK-DAG: %[[SIZE:.*]] = arith.constant 64
+// CHECK-DAG: %[[BYTES:.*]] = arith.constant 2
+// CHECK-DAG: %[[VECTORSIZE:.*]] = arith.constant 4
+
+// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[DELTA:.*]] = arith.subi %[[SIZE]], %[[LINEAR]]
+// CHECK: %[[COND1:.*]] = arith.cmpi ule, %[[DELTA]], %[[VECTORSIZE]]
+
+// CHECK: %[[DELTABYTES:.*]] = arith.muli %[[DELTA]], %[[BYTES]]
+// CHECK: %[[REM:.*]] = arith.remui %[[DELTABYTES]], %[[BYTES]]
+// CHECK: %[[COND2:.*]] = arith.cmpi ne, %[[REM]], %[[C0]]
+
+// CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]]
+// CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (vector<4xf16>) {
+// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]]
+// CHECK: } else {
+// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+// CHECK: return %[[IF]] : vector<4xf16>
+
+// -----
-// CHECK: %[[C1:.*]] = arith.constant 1
-// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[MUL0:.*]] = arith.muli %{{.*}}, %[[DIM1]]
-// CHECK: %[[C0:.*]] = arith.constant 0
-// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK: %[[MUL1:.*]] = arith.muli %{{.*}}, %[[DIM0]]
+// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+// CHECK: #map1 = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)>
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+// CHECK-SAME: %[[ARG3:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>) -> vector<4xi8> {
+ %cf0 = arith.constant 0 : i8
+ %res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]} : memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi8>
+ return %res : vector<4xi8>
+}
-// CHECK: %[[C1_1:.*]] = arith.constant 1
-// CHECK: %[[DIM1_1:.*]] = memref.dim %[[ARG0]], %[[C1_1]]
-// CHECK: %[[MUL2:.*]] = arith.muli %{{.*}}, %[[DIM1_1]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi8>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
+// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[C1]], %[[SIZES]]#1]
+// CHECK: %[[IF:.*]] = scf.if
+// CHECK: return
// -----
@@ -70,8 +87,8 @@ func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index,
%res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
return %res : vector<4xf32>
}
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[CST]]
// CHECK: return %[[RES]] : vector<4xf32>
// -----
@@ -85,8 +102,8 @@ func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_
%res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
return %res : vector<4xf32>
}
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[CST]]
// CHECK: return %[[RES]] : vector<4xf32>
// -----
@@ -103,10 +120,11 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
return %res : vector<4xf32>
}
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
// -----
@@ -122,7 +140,8 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
return %res : vector<1xf32>
}
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
-// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<1xf32>) {
+// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG1]]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
>From ddee079d1c12939e82d0972031faecec6ac771d5 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 14 Apr 2025 18:43:02 +0000
Subject: [PATCH 6/6] Addressing review feedbacks
---
.../AMDGPU/Transforms/TransferReadToLoad.cpp | 74 +++++++++++++------
.../Dialect/AMDGPU/transfer-read-to-load.mlir | 19 +++--
2 files changed, 61 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 4cab311f2b6c7..a7d402752deb7 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -24,7 +24,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
@@ -76,6 +76,9 @@ static LogicalResult transferPreconditions(
if (!memRefType.isLastDimUnitStride())
return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
+ if (memRefType.getElementTypeBitWidth() < 8)
+ return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");
+
// If there is broadcasting involved then we first load the unbroadcasted
// vector, and then broadcast it with `vector.broadcast`.
ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
@@ -127,6 +130,9 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
return res;
}
+static constexpr char kTransferReadNeedsMask[] =
+ "amdgpu.buffer_transfer_read_needs_mask";
+
namespace {
struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
@@ -134,7 +140,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
- if (readOp->hasAttr("amdgpu.buffer_transfer_read_needs_mask"))
+ if (readOp->hasAttr(kTransferReadNeedsMask))
return failure();
bool requiresBroadcasting = false;
@@ -154,38 +160,60 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+ SmallVector<OpFoldResult> strides =
+ stridedMetadata.getConstifiedMixedStrides();
+ SmallVector<OpFoldResult> sizes =
+ stridedMetadata.getConstifiedMixedSizes();
+ OpFoldResult offset =
+ stridedMetadata.getConstifiedMixedOffset();
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
- memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, elementBitWidth, elementBitWidth,
- stridedMetadata.getConstifiedMixedOffset(),
- stridedMetadata.getConstifiedMixedSizes(),
- stridedMetadata.getConstifiedMixedStrides(), indices);
- Value linearIndex =
- getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+ memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+ elementBitWidth, offset, sizes,
+ strides, indices);
// TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
// Note below doesn't give the correct result for the linearized size.
// Value totalSize = getValueOrCreateConstantIndexOp(
// rewriter, loc, linearizedInfo.linearizedSize);
- // It compute the mutiplied sizes of all dimensions instead of taking
+ // It computes the multiplied sizes of all dimensions instead of taking
// the maximum of each dimension size * stride.
SmallVector<AffineExpr> productExpressions;
SmallVector<Value> productResults;
unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
SmallVector<AffineExpr> symbols(2 * sourceRank);
- SmallVector<Value> offsetValues(2 * sourceRank);
+ SmallVector<Value> offsetValues;
bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+
+ size_t symbolIndex = 0;
for (size_t i = 0; i < sourceRank; ++i) {
- unsigned offsetIdx = 2 * i;
- productExpressions.push_back(symbols[offsetIdx] * symbols[offsetIdx + 1]);
- offsetValues[offsetIdx] = stridedMetadata.getStrides()[i];
- offsetValues[offsetIdx + 1] = stridedMetadata.getSizes()[i];
+ AffineExpr strideExpr, sizeExpr;
+ OpFoldResult stride = strides[i];
+ OpFoldResult size = sizes[i];
+ if (auto constantStride =
+ getConstantIntValue(stride)) {
+ strideExpr = rewriter.getAffineConstantExpr(*constantStride);
+ } else {
+ strideExpr = symbols[symbolIndex++];
+ offsetValues.push_back(getValueOrCreateConstantIndexOp(
+ rewriter, loc, stride));
+ }
+
+ if (auto constantSize =
+ getConstantIntValue(size)) {
+ sizeExpr = rewriter.getAffineConstantExpr(*constantSize);
+ } else {
+ sizeExpr = symbols[symbolIndex++];
+ offsetValues.push_back(getValueOrCreateConstantIndexOp(
+ rewriter, loc, size));
+ }
+
+ productExpressions.push_back(strideExpr * sizeExpr);
}
AffineMap maxMap = AffineMap::get(
- /*dimCount=*/0, /*symbolCount=*/symbols.size(), productExpressions,
+ /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
rewriter.getContext());
Value totalSize =
rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
@@ -193,32 +221,35 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
// delta = bufferSize - linearizedOffset
Value vectorSizeOffset =
rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+ Value linearIndex =
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
// 1) check if delta < vectorSize
Value isOutofBounds = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
+ loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
// 2) check if (detla_bytes % (32 / elementBitwidth) != 0)
Value deltaBytes = rewriter.create<arith::MulIOp>(
loc, delta,
rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
- loc, elementBitWidth < 32 ? 32 / elementBitWidth : 1);
+ loc, llvm::divideCeil(32, elementBitWidth));
Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne,
rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
rewriter.create<arith::ConstantIndexOp>(loc, 0));
// We take the fallback of transfer_read default lowering only it is both
- // out-of-bounds and not word aligned.
+ // out-of-bounds and not word aligned. The fallback ensures correct results
+ // when loading at the boundary of the buffer since buffer load returns
+ // inconsistent zeros for the whole word when boundary is crossed.
Value ifCondition =
rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
auto thenBuilder = [&](OpBuilder &builder, Location loc) {
Operation *read = builder.clone(*readOp.getOperation());
- read->setAttr("amdgpu.buffer_transfer_read_needs_mask",
- builder.getUnitAttr());
+ read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
Value readResult = read->getResult(0);
builder.create<scf::YieldOp>(loc, readResult);
};
@@ -243,7 +274,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
RewritePatternSet &patterns) {
patterns.add<TransferReadLowering>(patterns.getContext());
- vector::populateVectorTransferLoweringPatterns(patterns);
}
struct AmdgpuTransferReadToLoadPass final
diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
index 91b6d8b3137c8..d0805b6b8a973 100644
--- a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
@@ -12,7 +12,7 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
// CHECK: %[[FALSE:.*]] = arith.constant false
// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
-// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]]
+// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]]
// CHECK: } else {
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
@@ -39,7 +39,7 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[ARG2]]]
// CHECK: %[[DELTA:.*]] = arith.subi %[[SIZE]], %[[LINEAR]]
-// CHECK: %[[COND1:.*]] = arith.cmpi ule, %[[DELTA]], %[[VECTORSIZE]]
+// CHECK: %[[COND1:.*]] = arith.cmpi ult, %[[DELTA]], %[[VECTORSIZE]]
// CHECK: %[[DELTABYTES:.*]] = arith.muli %[[DELTA]], %[[BYTES]]
// CHECK: %[[REM:.*]] = arith.remui %[[DELTABYTES]], %[[BYTES]]
@@ -47,7 +47,7 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
// CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]]
// CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (vector<4xf16>) {
-// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]]
+// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG2]]]
// CHECK: } else {
// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
// CHECK: return %[[IF]] : vector<4xf16>
@@ -55,7 +55,7 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
// -----
// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
-// CHECK: #map1 = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)>
+// CHECK: #map1 = affine_map<()[s0, s1, s2] -> (s0 * s1, s2)>
// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
@@ -69,10 +69,9 @@ func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8,
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi8>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
-// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[C1]], %[[SIZES]]#1]
+// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
// CHECK: %[[IF:.*]] = scf.if
// CHECK: return
@@ -87,8 +86,8 @@ func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index,
%res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
return %res : vector<4xf32>
}
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[CST]]
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[RES:.*]] = vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[CST]], %[[ARG2]]
// CHECK: return %[[RES]] : vector<4xf32>
// -----
@@ -102,8 +101,8 @@ func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_
%res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
return %res : vector<4xf32>
}
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[CST]]
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[RES:.*]] = vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[CST]], %[[ARG2]]
// CHECK: return %[[RES]] : vector<4xf32>
// -----
More information about the llvm-commits
mailing list