[llvm] [mlir] [MLIR][AMDGPU] Adding dynamic size check to avoid subword buffer load (PR #135014)

Zhuoran Yin via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 14 14:09:32 PDT 2025


https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/135014

>From 27c5497a5dc96fe8a05e775e0c9f793927c5d59c Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 8 Apr 2025 17:30:05 +0000
Subject: [PATCH 1/6] Adding dynamic size check to avoid subword buffer load

---
 .../mlir/Dialect/AMDGPU/Transforms/Passes.td  | 13 ++-
 .../Dialect/AMDGPU/Transforms/CMakeLists.txt  |  1 +
 .../AMDGPU/Transforms/TransferReadToLoad.cpp  | 93 ++++++++++++++++---
 .../Dialect/AMDGPU/transfer-read-to-load.mlir | 48 +++++++++-
 .../llvm-project-overlay/mlir/BUILD.bazel     |  1 +
 5 files changed, 137 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 761caa448a57c..0e858108acf35 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -54,15 +54,20 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
 def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
   let summary = "Lower the operations from the vector transfer_read to vector load";
   let description = [{
-    This pass creates a transfer read op lowering. A vector trasfer read op
-    will be lowered to a combination of vector.load, arith.select and
-    vector.broadcast.
+    This pass creates a transfer read op lowering optimization. The lowering
+    will produce a conditional check at runtime. If within bounds, a vector
+    transfer read op will be lowered to a combination of vector.load, arith.select
+    and vector.broadcast. If not, it will fallback to the default lowering
+    of the transfer_read op.
 
     This pattern will make it possible for masked transfer_read to be lowered
     towards buffer load with bounds check, allowing a more optimized global
     load accessing pattern compared with existing implementation of
     llvm.intr.masked.load on vectors.
   }];
-  let dependentDialects = [];
+  let dependentDialects = [
+    "scf::SCFDialect",
+    "memref::MemRefDialect"
+  ];
 }
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index bc5b6e9186449..8709a27e0168e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
   MLIRAMDGPUUtils
   MLIRArithDialect
   MLIRMemRefDialect
+  MLIRSCFDialect
   MLIRVectorDialect
   MLIRControlFlowDialect
   MLIRFuncDialect
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 3c1a2eb962037..519f695d99f91 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -9,6 +9,8 @@
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
@@ -108,6 +110,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
+    if (readOp->hasAttr("amdgpu.transformed"))
+      return failure();
 
     bool requiresBroadcasting = false;
     VectorType unbroadcastedVectorType;
@@ -117,20 +121,85 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     }
 
     Location loc = readOp.getLoc();
-    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getPadding());
-    Value load = rewriter.create<vector::LoadOp>(
-        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
-    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                                 readOp.getMask(), load, fill);
-
-    // Insert a broadcasting op if required.
-    if (requiresBroadcasting) {
-      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
-                                                 res);
+    Value src = readOp.getSource();
+    MemRefType memRefType = cast<MemRefType>(src.getType());
+    ArrayRef<int64_t> shape = memRefType.getShape();
+
+    Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stride = one;
+
+    // Compute the linear index by linearIndex += indices[i] * stride
+    for (int i = shape.size() - 1; i >= 0; --i) {
+      Value currentIndex = readOp.getIndices()[i];
+      Value strideIndexed =
+          rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
+      linearIndex =
+          rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
+
+      if (i == 0)
+        break;
+
+      // Update stride for the next dimension
+      Value nextStride;
+      if (shape[i] != ShapedType::kDynamic) {
+        nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+      } else {
+        nextStride = rewriter.create<memref::DimOp>(loc, src, i);
+      }
+      stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
+    }
+
+    // Add vector size offset to linear index
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value upperBoundIndex =
+        rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
+
+    Value totalSize = one;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      Value dimensionSize;
+      if (shape[i] != ShapedType::kDynamic) {
+        dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+      } else {
+        dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
+      }
+      totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
     }
 
-    rewriter.replaceOp(readOp, res);
+    Value isInBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+
+    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+      Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+                                                   readOp.getPadding());
+      Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getSource(),
+                                                  readOp.getIndices());
+      Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getMask(), load, fill);
+
+      // Insert a broadcasting op if required.
+      if (requiresBroadcasting) {
+        res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+                                                  res);
+      }
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Operation *read = builder.clone(*readOp.getOperation());
+      read->setAttr("amdgpu.transformed", builder.getUnitAttr());
+      Value readResult = read->getResult(0);
+      builder.create<scf::YieldOp>(loc, readResult);
+    };
+
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, isInBounds, thenBuilder, elseBuilder);
+
+    rewriter.replaceOp(readOp, ifOp);
 
     return success();
   }
diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
index 3e1283579f2b1..776a047e6a85d 100644
--- a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
@@ -10,10 +10,54 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
   return %res : vector<4xf32>
 }
 // CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[MUL0:.*]] = arith.muli %[[ARG1]], %[[C1]]
+// CHECK: %[[ADD0:.*]] = arith.addi %[[C0]], %[[MUL0]]
+// CHECK: %[[C8:.*]] = arith.constant 8
+// CHECK: %[[MUL1:.*]] = arith.muli %[[C1]], %[[C8]]
+// CHECK: %[[MUL2:.*]] = arith.muli %[[ARG1]], %[[MUL1]]
+// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[MUL2]]
+// CHECK: %[[C4:.*]] = arith.constant 4
+// CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[C4]]
+
+// CHECK: %[[MUL3:.*]] = arith.muli %[[C1]], %[[C8]]
+// CHECK: %[[MUL4:.*]] = arith.muli
+
+// CHECK: %[[CMP:.*]] = arith.cmpi ule, %[[ADD2]], %[[MUL4]]
+// CHECK: %[[IF:.*]] = scf.if %[[CMP]] -> (vector<4xf32>) {
+
 // CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
 // CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<4xf32>
+
+// CHECK: } else {
+// CHECK: %[[LOAD:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {amdgpu.transformed, in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+
+// CHECK: return %[[IF]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer_dynamic(%mem : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[MUL0:.*]] = arith.muli %{{.*}}, %[[DIM1]]
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[MUL1:.*]] = arith.muli %{{.*}}, %[[DIM0]]
+
+// CHECK: %[[C1_1:.*]] = arith.constant 1
+// CHECK: %[[DIM1_1:.*]] = memref.dim %[[ARG0]], %[[C1_1]]
+// CHECK: %[[MUL2:.*]] = arith.muli %{{.*}}, %[[DIM1_1]]
 
 // -----
 
@@ -64,7 +108,6 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
 // CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
 // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
-// CHECK: return %[[BROADCAST]] : vector<4xf32>
 
 // -----
 
@@ -83,4 +126,3 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
 // CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
 // CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<1xf32>
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 141986392917e..c4d87484fd5d5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1568,6 +1568,7 @@ cc_library(
         ":IR",
         ":MemRefDialect",
         ":Pass",
+        ":SCFDialect",
         ":SideEffectInterfaces",
         ":Support",
         ":TransformUtils",

>From eac8c2bd90117134bd21a0a07f0df33e75b1d081 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 10 Apr 2025 17:30:31 +0000
Subject: [PATCH 2/6] Relaxing condition to do bounds check

---
 .../AMDGPU/Transforms/TransferReadToLoad.cpp  | 78 ++++++++++++-------
 1 file changed, 50 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 519f695d99f91..9dbcebee1252a 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -103,6 +103,23 @@ static LogicalResult transferPreconditions(
   return success();
 }
 
+static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
+                                           vector::TransferReadOp readOp,
+                                           bool requiresBroadcasting,
+                                           VectorType unbroadcastedVectorType) {
+  Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+                                               readOp.getPadding());
+  Value load = builder.create<vector::LoadOp>(
+      loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
+  Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+                                              readOp.getMask(), load, fill);
+  // Insert a broadcasting op if required.
+  if (requiresBroadcasting) {
+    res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
+  }
+  return res;
+}
+
 namespace {
 
 struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
@@ -150,14 +167,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
       stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
     }
 
-    // Add vector size offset to linear index
-    VectorType vectorType = readOp.getVectorType();
-    int64_t vectorSize = vectorType.getNumElements();
-    Value vectorSizeOffset =
-        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
-    Value upperBoundIndex =
-        rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
-
     Value totalSize = one;
     for (size_t i = 0; i < shape.size(); ++i) {
       Value dimensionSize;
@@ -169,35 +178,48 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
       totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
     }
 
-    Value isInBounds = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+    // delta = bufferSize - linearizedOffset
+    // 1) check if delta < vectorSize
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
+
+    // 2) check if (detla(bytes) % (32 / elementBitwidth) != 0)
+    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+    Value deltaBytes = rewriter.create<arith::MulIOp>(
+        loc, delta,
+        rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
+    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
+        loc, elementBitWidth < 32 ? 32 / elementBitWidth : 1);
+    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne,
+        rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
+        rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+    // We take the fallback of transfer_read default lowering only if it is both
+    // out-of-bounds and not word aligned.
+    Value ifCondition =
+        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
 
     auto thenBuilder = [&](OpBuilder &builder, Location loc) {
-      Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                                   readOp.getPadding());
-      Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getSource(),
-                                                  readOp.getIndices());
-      Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getMask(), load, fill);
-
-      // Insert a broadcasting op if required.
-      if (requiresBroadcasting) {
-        res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
-                                                  res);
-      }
-      rewriter.create<scf::YieldOp>(loc, res);
-    };
-
-    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
       Operation *read = builder.clone(*readOp.getOperation());
       read->setAttr("amdgpu.transformed", builder.getUnitAttr());
       Value readResult = read->getResult(0);
       builder.create<scf::YieldOp>(loc, readResult);
     };
 
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Value res = createVectorLoadForMaskedLoad(
+          builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+
     auto ifOp =
-        rewriter.create<scf::IfOp>(loc, isInBounds, thenBuilder, elseBuilder);
+        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
 
     rewriter.replaceOp(readOp, ifOp);
 

>From ca9d7df27328cebb792b3149716a89402eae97f4 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 10 Apr 2025 20:48:50 +0000
Subject: [PATCH 3/6] Use affine for index and size computations

---
 .../AMDGPU/Transforms/TransferReadToLoad.cpp  | 92 +++++++++++--------
 1 file changed, 52 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 9dbcebee1252a..a60718041cbcb 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -9,10 +9,15 @@
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
@@ -139,57 +144,64 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 
     Location loc = readOp.getLoc();
     Value src = readOp.getSource();
-    MemRefType memRefType = cast<MemRefType>(src.getType());
-    ArrayRef<int64_t> shape = memRefType.getShape();
-
-    Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    Value stride = one;
-
-    // Compute the linear index by linearIndex += indices[i] * stride
-    for (int i = shape.size() - 1; i >= 0; --i) {
-      Value currentIndex = readOp.getIndices()[i];
-      Value strideIndexed =
-          rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
-      linearIndex =
-          rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
-
-      if (i == 0)
-        break;
-
-      // Update stride for the next dimension
-      Value nextStride;
-      if (shape[i] != ShapedType::kDynamic) {
-        nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
-      } else {
-        nextStride = rewriter.create<memref::DimOp>(loc, src, i);
-      }
-      stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
-    }
 
-    Value totalSize = one;
-    for (size_t i = 0; i < shape.size(); ++i) {
-      Value dimensionSize;
-      if (shape[i] != ShapedType::kDynamic) {
-        dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
-      } else {
-        dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
-      }
-      totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+    // Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<OpFoldResult> indices = readOp.getIndices();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    memref::LinearizedMemRefInfo linearizedInfo;
+    OpFoldResult linearizedIndices;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, elementBitWidth, elementBitWidth,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(), indices);
+    // OpFoldResult linearIndexSize = linearizedInfo.linearizedSize;
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+
+    // Note below doesn't give the correct result for the linearized size.
+    // It compute the mutiplied sizes of all dimensions instead of taking
+    // the maximum of each dimension size * stride.
+    // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
+    // Value totalSize = getValueOrCreateConstantIndexOp(
+    //    rewriter, loc, linearizedInfo.linearizedSize);
+    SmallVector<AffineExpr> productExpressions;
+    SmallVector<Value> productResults;
+    unsigned sourceRank =
+        cast<ShapedType>(readOp.getSource().getType()).getRank();
+
+    SmallVector<AffineExpr> symbols(2 * sourceRank);
+    SmallVector<Value> offsetValues(2 * sourceRank);
+    bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+    for (size_t i = 0; i < sourceRank; ++i) {
+      unsigned offsetIdx = 2 * i;
+      productExpressions.push_back(symbols[offsetIdx] * symbols[offsetIdx + 1]);
+      offsetValues[offsetIdx] = stridedMetadata.getStrides()[i];
+      offsetValues[offsetIdx + 1] = stridedMetadata.getSizes()[i];
     }
 
+    AffineMap maxMap = AffineMap::get(
+        /*dimCount=*/0, /*symbolCount=*/symbols.size(), productExpressions,
+        rewriter.getContext());
+    Value totalSize =
+        rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
+
     // delta = bufferSize - linearizedOffset
-    // 1) check if delta < vectorSize
-    VectorType vectorType = readOp.getVectorType();
-    int64_t vectorSize = vectorType.getNumElements();
     Value vectorSizeOffset =
         rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
     Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+
+    // 1) check if delta < vectorSize
     Value isOutofBounds = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
 
     // 2) check if (detla(bytes) % (32 / elementBitwidth) != 0)
-    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
     Value deltaBytes = rewriter.create<arith::MulIOp>(
         loc, delta,
         rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));

>From a09cd5163e5e9d17b3ec2f3e627037320bde9e46 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 11 Apr 2025 18:37:11 +0000
Subject: [PATCH 4/6] Invoking vector transfer lowering pattern in amdgpu pass

---
 .../AMDGPU/Transforms/TransferReadToLoad.cpp  | 22 +++++++++++--------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index a60718041cbcb..60839cbfaeae1 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -16,12 +16,14 @@
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
 namespace mlir::amdgpu {
@@ -132,7 +134,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    if (readOp->hasAttr("amdgpu.transformed"))
+    if (readOp->hasAttr("amdgpu.buffer_transfer_read_needs_mask"))
       return failure();
 
     bool requiresBroadcasting = false;
@@ -148,7 +150,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     VectorType vectorType = readOp.getVectorType();
     int64_t vectorSize = vectorType.getNumElements();
     int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
-    // Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     SmallVector<OpFoldResult> indices = readOp.getIndices();
 
     auto stridedMetadata =
@@ -161,16 +162,15 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(), indices);
-    // OpFoldResult linearIndexSize = linearizedInfo.linearizedSize;
     Value linearIndex =
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
 
-    // Note below doesn't give the correct result for the linearized size.
-    // It compute the mutiplied sizes of all dimensions instead of taking
-    // the maximum of each dimension size * stride.
     // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
+    // Note below doesn't give the correct result for the linearized size.
     // Value totalSize = getValueOrCreateConstantIndexOp(
     //    rewriter, loc, linearizedInfo.linearizedSize);
+    // It computes the multiplied sizes of all dimensions instead of taking
+    // the maximum of each dimension size * stride.
     SmallVector<AffineExpr> productExpressions;
     SmallVector<Value> productResults;
     unsigned sourceRank =
@@ -201,7 +201,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     Value isOutofBounds = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
 
-    // 2) check if (detla(bytes) % (32 / elementBitwidth) != 0)
+    // 2) check if (delta_bytes % (32 / elementBitwidth) != 0)
     Value deltaBytes = rewriter.create<arith::MulIOp>(
         loc, delta,
         rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
@@ -219,7 +219,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 
     auto thenBuilder = [&](OpBuilder &builder, Location loc) {
       Operation *read = builder.clone(*readOp.getOperation());
-      read->setAttr("amdgpu.transformed", builder.getUnitAttr());
+      read->setAttr("amdgpu.buffer_transfer_read_needs_mask",
+                    builder.getUnitAttr());
       Value readResult = read->getResult(0);
       builder.create<scf::YieldOp>(loc, readResult);
     };
@@ -244,6 +245,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TransferReadLowering>(patterns.getContext());
+  vector::populateVectorTransferLoweringPatterns(patterns);
 }
 
 struct AmdgpuTransferReadToLoadPass final
@@ -252,6 +254,8 @@ struct AmdgpuTransferReadToLoadPass final
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateAmdgpuTransferReadToLoadPatterns(patterns);
-    walkAndApplyPatterns(getOperation(), std::move(patterns));
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      return signalPassFailure();
+    }
   }
 };

>From d5967040d1889ff6f13d351564bd1f78ec014486 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 11 Apr 2025 21:34:42 +0000
Subject: [PATCH 5/6] Matching unit test with latest implementation

---
 .../AMDGPU/Transforms/TransferReadToLoad.cpp  |   6 +-
 .../Dialect/AMDGPU/transfer-read-to-load.mlir | 119 ++++++++++--------
 2 files changed, 71 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 60839cbfaeae1..4cab311f2b6c7 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -154,9 +154,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
-    memref::LinearizedMemRefInfo linearizedInfo;
     OpFoldResult linearizedIndices;
-    std::tie(linearizedInfo, linearizedIndices) =
+    std::tie(std::ignore, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, elementBitWidth, elementBitWidth,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -173,8 +172,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     // the maximum of each dimension size * stride.
     SmallVector<AffineExpr> productExpressions;
     SmallVector<Value> productResults;
-    unsigned sourceRank =
-        cast<ShapedType>(readOp.getSource().getType()).getRank();
+    unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
 
     SmallVector<AffineExpr> symbols(2 * sourceRank);
     SmallVector<Value> offsetValues(2 * sourceRank);
diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
index 776a047e6a85d..91b6d8b3137c8 100644
--- a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
@@ -9,55 +9,72 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
   %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[C0:.*]] = arith.constant 0
-// CHECK: %[[C1:.*]] = arith.constant 1
-// CHECK: %[[MUL0:.*]] = arith.muli %[[ARG1]], %[[C1]]
-// CHECK: %[[ADD0:.*]] = arith.addi %[[C0]], %[[MUL0]]
-// CHECK: %[[C8:.*]] = arith.constant 8
-// CHECK: %[[MUL1:.*]] = arith.muli %[[C1]], %[[C8]]
-// CHECK: %[[MUL2:.*]] = arith.muli %[[ARG1]], %[[MUL1]]
-// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[MUL2]]
-// CHECK: %[[C4:.*]] = arith.constant 4
-// CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[C4]]
-
-// CHECK: %[[MUL3:.*]] = arith.muli %[[C1]], %[[C8]]
-// CHECK: %[[MUL4:.*]] = arith.muli
-
-// CHECK: %[[CMP:.*]] = arith.cmpi ule, %[[ADD2]], %[[MUL4]]
-// CHECK: %[[IF:.*]] = scf.if %[[CMP]] -> (vector<4xf32>) {
-
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
-// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
+// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]]
 
 // CHECK: } else {
-// CHECK: %[[LOAD:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {amdgpu.transformed, in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %[[ARG2]], %[[LOAD]]
 
 // CHECK: return %[[IF]] : vector<4xf32>
 
 // -----
 
-// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic(
-// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
-// CHECK-SAME: %[[ARG1:.*]]: index
-// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
-func.func @transfer_to_maskedload_fatrawbuffer_dynamic(%mem : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
-  %cf0 = arith.constant 0.0 : f32
-  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
-  return %res : vector<4xf32>
+// CHECK: #map = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_f16(
+// CHECK-SAME: %[[ARG0:.+]]: memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>,
+// CHECK-SAME: %[[ARG1:.+]]: index, %[[ARG2:.+]]: index,
+// CHECK-SAME: %[[ARG3:.+]]: vector<4xi1>)
+func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>) -> vector<4xf16> {
+  %cf0 = arith.constant 0.0 : f16
+  %res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]} : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
+  return %res : vector<4xf16>
 }
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
+// CHECK-DAG: %[[SIZE:.*]] = arith.constant 64
+// CHECK-DAG: %[[BYTES:.*]] = arith.constant 2
+// CHECK-DAG: %[[VECTORSIZE:.*]] = arith.constant 4
+
+// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[DELTA:.*]] = arith.subi %[[SIZE]], %[[LINEAR]]
+// CHECK: %[[COND1:.*]] = arith.cmpi ule, %[[DELTA]], %[[VECTORSIZE]]
+
+// CHECK: %[[DELTABYTES:.*]] = arith.muli %[[DELTA]], %[[BYTES]]
+// CHECK: %[[REM:.*]] = arith.remui %[[DELTABYTES]], %[[BYTES]]
+// CHECK: %[[COND2:.*]] = arith.cmpi ne, %[[REM]], %[[C0]]
+
+// CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]]
+// CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (vector<4xf16>) {
+// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]]
+// CHECK: } else {
+// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+// CHECK: return %[[IF]] : vector<4xf16>
+
+// -----
 
-// CHECK: %[[C1:.*]] = arith.constant 1
-// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[MUL0:.*]] = arith.muli %{{.*}}, %[[DIM1]]
-// CHECK: %[[C0:.*]] = arith.constant 0
-// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK: %[[MUL1:.*]] = arith.muli %{{.*}}, %[[DIM0]]
+// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+// CHECK: #map1 = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)>
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+// CHECK-SAME: %[[ARG3:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>) -> vector<4xi8> {
+  %cf0 = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]} : memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi8>
+  return %res : vector<4xi8>
+}
 
-// CHECK: %[[C1_1:.*]] = arith.constant 1
-// CHECK: %[[DIM1_1:.*]] = memref.dim %[[ARG0]], %[[C1_1]]
-// CHECK: %[[MUL2:.*]] = arith.muli %{{.*}}, %[[DIM1_1]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi8>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
+// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[C1]], %[[SIZES]]#1]
+// CHECK: %[[IF:.*]] = scf.if
+// CHECK: return
 
 // -----
 
@@ -70,8 +87,8 @@ func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index,
   %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[CST]]
 // CHECK: return %[[RES]] : vector<4xf32>
 
 // -----
@@ -85,8 +102,8 @@ func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_
   %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[CST]]
 // CHECK: return %[[RES]] : vector<4xf32>
 
 // -----
@@ -103,10 +120,11 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
       : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
 // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
 
 // -----
@@ -122,7 +140,8 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
       : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
   return %res : vector<1xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.0
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
-// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<1xf32>) {
+// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG1]]]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]

>From ddee079d1c12939e82d0972031faecec6ac771d5 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 14 Apr 2025 18:43:02 +0000
Subject: [PATCH 6/6] Address review feedback

---
 .../AMDGPU/Transforms/TransferReadToLoad.cpp  | 74 +++++++++++++------
 .../Dialect/AMDGPU/transfer-read-to-load.mlir | 19 +++--
 2 files changed, 61 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 4cab311f2b6c7..a7d402752deb7 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -24,7 +24,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
 
 namespace mlir::amdgpu {
 #define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
@@ -76,6 +76,9 @@ static LogicalResult transferPreconditions(
   if (!memRefType.isLastDimUnitStride())
     return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
 
+  if (memRefType.getElementTypeBitWidth() < 8)
+    return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");
+
   // If there is broadcasting involved then we first load the unbroadcasted
   // vector, and then broadcast it with `vector.broadcast`.
   ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
@@ -127,6 +130,9 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
   return res;
 }
 
+static constexpr char kTransferReadNeedsMask[] =
+    "amdgpu.buffer_transfer_read_needs_mask";
+
 namespace {
 
 struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
@@ -134,7 +140,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    if (readOp->hasAttr("amdgpu.buffer_transfer_read_needs_mask"))
+    if (readOp->hasAttr(kTransferReadNeedsMask))
       return failure();
 
     bool requiresBroadcasting = false;
@@ -154,38 +160,60 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    SmallVector<OpFoldResult> strides =
+        stridedMetadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> sizes =
+        stridedMetadata.getConstifiedMixedSizes();
+    OpFoldResult offset =
+        stridedMetadata.getConstifiedMixedOffset();
     OpFoldResult linearizedIndices;
     std::tie(std::ignore, linearizedIndices) =
-        memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, elementBitWidth, elementBitWidth,
-            stridedMetadata.getConstifiedMixedOffset(),
-            stridedMetadata.getConstifiedMixedSizes(),
-            stridedMetadata.getConstifiedMixedStrides(), indices);
-    Value linearIndex =
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+                                                 elementBitWidth, offset, sizes,
+                                                 strides, indices);
 
     // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
     // Note below doesn't give the correct result for the linearized size.
     // Value totalSize = getValueOrCreateConstantIndexOp(
     //    rewriter, loc, linearizedInfo.linearizedSize);
-    // It compute the mutiplied sizes of all dimensions instead of taking
+    // It computes the multiplied sizes of all dimensions instead of taking
     // the maximum of each dimension size * stride.
     SmallVector<AffineExpr> productExpressions;
     SmallVector<Value> productResults;
     unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
 
     SmallVector<AffineExpr> symbols(2 * sourceRank);
-    SmallVector<Value> offsetValues(2 * sourceRank);
+    SmallVector<Value> offsetValues;
     bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+
+    size_t symbolIndex = 0;
     for (size_t i = 0; i < sourceRank; ++i) {
-      unsigned offsetIdx = 2 * i;
-      productExpressions.push_back(symbols[offsetIdx] * symbols[offsetIdx + 1]);
-      offsetValues[offsetIdx] = stridedMetadata.getStrides()[i];
-      offsetValues[offsetIdx + 1] = stridedMetadata.getSizes()[i];
+      AffineExpr strideExpr, sizeExpr;
+      OpFoldResult stride = strides[i];
+      OpFoldResult size = sizes[i];
+      if (auto constantStride =
+              getConstantIntValue(stride)) {
+        strideExpr = rewriter.getAffineConstantExpr(*constantStride);
+      } else {
+        strideExpr = symbols[symbolIndex++];
+        offsetValues.push_back(getValueOrCreateConstantIndexOp(
+            rewriter, loc, stride));
+      }
+
+      if (auto constantSize =
+              getConstantIntValue(size)) {
+        sizeExpr = rewriter.getAffineConstantExpr(*constantSize);
+      } else {
+        sizeExpr = symbols[symbolIndex++];
+        offsetValues.push_back(getValueOrCreateConstantIndexOp(
+            rewriter, loc, size));
+      }
+
+      productExpressions.push_back(strideExpr * sizeExpr);
     }
 
     AffineMap maxMap = AffineMap::get(
-        /*dimCount=*/0, /*symbolCount=*/symbols.size(), productExpressions,
+        /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
         rewriter.getContext());
     Value totalSize =
         rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
@@ -193,32 +221,35 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     // delta = bufferSize - linearizedOffset
     Value vectorSizeOffset =
         rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
     Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
 
     // 1) check if delta < vectorSize
     Value isOutofBounds = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
+        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
 
     // 2) check if (detla_bytes % (32 / elementBitwidth) != 0)
     Value deltaBytes = rewriter.create<arith::MulIOp>(
         loc, delta,
         rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
     Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
-        loc, elementBitWidth < 32 ? 32 / elementBitWidth : 1);
+        loc, llvm::divideCeil(32, elementBitWidth));
     Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::ne,
         rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
         rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
     // We take the fallback of transfer_read default lowering only it is both
-    // out-of-bounds and not word aligned.
+    // out-of-bounds and not word aligned. The fallback ensures correct results
+    // when loading at the boundary of the buffer since buffer load returns
+    // inconsistent zeros for the whole word when boundary is crossed.
     Value ifCondition =
         rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
 
     auto thenBuilder = [&](OpBuilder &builder, Location loc) {
       Operation *read = builder.clone(*readOp.getOperation());
-      read->setAttr("amdgpu.buffer_transfer_read_needs_mask",
-                    builder.getUnitAttr());
+      read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
       Value readResult = read->getResult(0);
       builder.create<scf::YieldOp>(loc, readResult);
     };
@@ -243,7 +274,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TransferReadLowering>(patterns.getContext());
-  vector::populateVectorTransferLoweringPatterns(patterns);
 }
 
 struct AmdgpuTransferReadToLoadPass final
diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
index 91b6d8b3137c8..d0805b6b8a973 100644
--- a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
@@ -12,7 +12,7 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
 
 // CHECK: %[[FALSE:.*]] = arith.constant false
 // CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
-// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]]
+// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]]
 
 // CHECK: } else {
 // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
@@ -39,7 +39,7 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
 
 // CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[ARG2]]]
 // CHECK: %[[DELTA:.*]] = arith.subi %[[SIZE]], %[[LINEAR]]
-// CHECK: %[[COND1:.*]] = arith.cmpi ule, %[[DELTA]], %[[VECTORSIZE]]
+// CHECK: %[[COND1:.*]] = arith.cmpi ult, %[[DELTA]], %[[VECTORSIZE]]
 
 // CHECK: %[[DELTABYTES:.*]] = arith.muli %[[DELTA]], %[[BYTES]]
 // CHECK: %[[REM:.*]] = arith.remui %[[DELTABYTES]], %[[BYTES]]
@@ -47,7 +47,7 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
 
 // CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]]
 // CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (vector<4xf16>) {
-// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]]
+// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG2]]]
 // CHECK: } else {
 // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
 // CHECK: return %[[IF]] : vector<4xf16>
@@ -55,7 +55,7 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
 // -----
 
 // CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
-// CHECK: #map1 = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)>
+// CHECK: #map1 = affine_map<()[s0, s1, s2] -> (s0 * s1, s2)>
 // CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
 // CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>
 // CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
@@ -69,10 +69,9 @@ func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8,
 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi8>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
 // CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
-// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[C1]], %[[SIZES]]#1]
+// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
 // CHECK: %[[IF:.*]] = scf.if
 // CHECK: return
 
@@ -87,8 +86,8 @@ func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index,
   %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[CST]]
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[RES:.*]] = vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[CST]], %[[ARG2]]
 // CHECK: return %[[RES]] : vector<4xf32>
 
 // -----
@@ -102,8 +101,8 @@ func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_
   %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
   return %res : vector<4xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
-// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[CST]]
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[RES:.*]] = vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[CST]], %[[ARG2]]
 // CHECK: return %[[RES]] : vector<4xf32>
 
 // -----



More information about the llvm-commits mailing list