[Mlir-commits] [mlir] Support non minor id maps (PR #178992)
Michael Platings
llvmlistbot at llvm.org
Tue Feb 3 08:06:14 PST 2026
https://github.com/mplatings updated https://github.com/llvm/llvm-project/pull/178992
>From cbd01c755b7b73c97af90661fac668748280865f Mon Sep 17 00:00:00 2001
From: Michael Platings <michael.platings at arm.com>
Date: Fri, 30 Jan 2026 17:47:22 +0000
Subject: [PATCH] [mlir][vector-to-gpu]: Lower transposed strided transfer_read
Add support for lowering vector.transfer_read to
gpu.subgroup_mma_load_matrix when the permutation_map is a transpose
whose second result is any non-minor dimension, not only the
second-most-minor one, e.g. (d0, d1, d2) -> (d2, d0).
---
.../Conversion/VectorToGPU/VectorToGPU.cpp | 142 +++++++++---------
.../VectorToGPU/vector-to-mma-ops.mlir | 87 +++++++++++
2 files changed, 155 insertions(+), 74 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 335786f554c02..65433ae6eb1c6 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -95,78 +95,71 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
return true;
}
-// Return true if the given map represents a transposed matrix load,
-// i.e. (d0, d1, ...) -> (dn-1, dn-2).
-static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
+// Test whether the permutation map's first result corresponds to its last
+// dimension.
+//
+// In contexts where we only accept maps that have the last (most minor)
+// dimension as exactly one of the results, this is sufficient to classify
+// whether it represents a transpose.
+static bool isFirstResultLastMapDimension(AffineMap permutationMap) {
MLIRContext *ctx = permutationMap.getContext();
- // Local OpBuilder is fine here, we just build attributes.
- OpBuilder b(ctx);
- auto nDim = permutationMap.getNumDims();
- AffineExpr zero = b.getAffineConstantExpr(0);
- if (nDim < 2) {
- // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
- AffineExpr dim0 = b.getAffineDimExpr(0);
- return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
- }
-
- AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
- AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
- // Support both transposed and transposed+broadcasted cases.
- return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
- permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
+ unsigned nDim = permutationMap.getNumDims();
+ return nDim && permutationMap.getNumResults() &&
+ permutationMap.getResult(0) == getAffineDimExpr(nDim - 1, ctx);
}
-// Return true if the given map represents a minor identity map with "strided"
-// dimensions i.e. (d0, d1, ..., dn-1) -> (da, dn-1) where 0 <= a <= n-1.
+// Return the `leadDimension` (row stride) implied by |permutationMap| for
+// |type|, if |type| is a memref with a statically-known layout.
//
-// This currently doesn't support permuted or broadcast dimensions.
-static bool isStridedMinorIdentity(AffineMap permutationMap) {
- if (permutationMap.getNumResults() != 2)
- return false;
-
- AffineExpr innerResult = permutationMap.getResult(1);
- const unsigned nDim = permutationMap.getNumDims();
- if (innerResult != getAffineDimExpr(nDim - 1, permutationMap.getContext()))
- return false;
-
- return isa<AffineDimExpr>(permutationMap.getResult(0));
-}
-
-// Return the stride for the "row" dimension of |type| if it is a memref
-// and has a constant stride.
+// The `leadDimension` is the stride (in elements) between consecutive rows in
+// the 2D view described by |permutationMap|. This helper supports the subset
+// of maps permitted by vector.transfer_read:
+// - Exactly 2 results.
+// - Each result is either an affine dimension or the constant 0 (broadcast).
+//
+// Constraints:
+// - Requires the most minor memref stride to be 1.
+//
+// Broadcast:
+// - If either result is constant 0, the implied `leadDimension` is 0.
static std::optional<int64_t>
getStaticallyKnownRowStride(ShapedType type, AffineMap permutationMap) {
auto memrefType = dyn_cast<MemRefType>(type);
if (!memrefType)
return std::nullopt;
-
- // We only support permutation maps with two results i.e. loading a 2D matrix.
- if (2 != permutationMap.getNumResults())
- return std::nullopt;
-
// If the memref is 0 or 1D the horizontal stride is 0.
if (memrefType.getRank() < 2)
return 0;
int64_t offset = 0;
- SmallVector<int64_t, 2> strides;
+ SmallVector<int64_t> strides;
if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
strides.back() != 1)
return std::nullopt;
- int stridePostion = strides.size() - 2;
- // We need to be careful if we have a permutation map i.e. (d0, d1) -> (d1,
- // d0), in this case we want to get the stride of the rows i.e. d0 but the
- // logic below will extract the column stride d1 because the dims have been
- // swapped.
- if (!permutationMap.isPermutation()) {
- // It's possible we have a constant 0 expression here (the permutation map
- // must be an affine projection). Projected dimensions are already handled
- // by the caller.
- if (auto affineDimExpr =
- dyn_cast<AffineDimExpr>(permutationMap.getResult(0)))
- stridePostion = affineDimExpr.getPosition();
+ if (permutationMap.getNumResults() != 2)
+ return std::nullopt;
+
+ unsigned strideIndex = strides.size();
+
+ for (AffineExpr result : permutationMap.getResults()) {
+ if (auto dim = dyn_cast<AffineDimExpr>(result)) {
+ strideIndex = std::min(strideIndex, dim.getPosition());
+ continue;
+ }
+ auto cst = dyn_cast<AffineConstantExpr>(result);
+ if (!cst || cst.getValue() != 0)
+ return std::nullopt;
+ // A broadcast result forces row stride to 0.
+ return 0;
}
- const int64_t stride = strides[stridePostion];
+
+ // Structural validity check: ensure that the map selects at least one
+ // dimension more major than the most minor dimension. This also excludes
+ // degenerate cases where both results map to the most minor dimension.
+ if (strideIndex + 1 >= strides.size())
+ return std::nullopt;
+
+ int64_t stride = strides[strideIndex];
if (stride == ShapedType::kDynamic)
return std::nullopt;
return stride;
@@ -178,8 +171,8 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
readOp.getVectorType().getRank() != 2)
return false;
- AffineMap map = readOp.getPermutationMap();
- if (!getStaticallyKnownRowStride(readOp.getShapedType(), map))
+ AffineMap permutationMap = readOp.getPermutationMap();
+ if (!getStaticallyKnownRowStride(readOp.getShapedType(), permutationMap))
return false;
// Only allow integer types if the signedness can be inferred.
@@ -189,12 +182,9 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
return false;
MLIRContext *ctx = readOp.getContext();
- AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
- AffineExpr zero = getAffineConstantExpr(0, ctx);
- auto broadcastInnerDim =
- AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
- return isStridedMinorIdentity(map) || map == broadcastInnerDim ||
- isTransposeMatrixLoadMap(map);
+ AffineExpr innerDim = getAffineDimExpr(permutationMap.getNumDims() - 1, ctx);
+ return permutationMap.getResult(0) == innerDim ||
+ permutationMap.getResult(1) == innerDim;
}
// Return true if the transfer op can be converted to a MMA matrix store.
@@ -208,11 +198,17 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
writeOp.getVectorType().getRank() != 2)
return false;
- AffineMap map = writeOp.getPermutationMap();
- if (!getStaticallyKnownRowStride(writeOp.getShapedType(), map))
+ AffineMap permutationMap = writeOp.getPermutationMap();
+ std::optional<int64_t> stride =
+ getStaticallyKnownRowStride(writeOp.getShapedType(), permutationMap);
+ // Stride of zero means broadcast which is not permitted for writes.
+ if (!stride.has_value() || stride.value() == 0)
return false;
+
+ MLIRContext *ctx = writeOp.getContext();
+ AffineExpr innerDim = getAffineDimExpr(permutationMap.getNumDims() - 1, ctx);
// TODO: Support transpose once it is added to GPU dialect ops.
- return isStridedMinorIdentity(map);
+ return permutationMap.getResult(1) == innerDim;
}
/// Return true if the constant is a splat to a 2D vector so that it can be
@@ -584,21 +580,19 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
assert(transferReadSupportsMMAMatrixType(op) &&
"expected convertible operation");
- AffineMap map = op.getPermutationMap();
+ AffineMap permutationMap = op.getPermutationMap();
std::optional<int64_t> stride =
- getStaticallyKnownRowStride(op.getShapedType(), map);
+ getStaticallyKnownRowStride(op.getShapedType(), permutationMap);
if (!stride.has_value()) {
LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
- bool isTranspose = isTransposeMatrixLoadMap(map);
-
- // Handle broadcast by setting the stride to 0.
- if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
- assert(cstExpr.getValue() == 0);
- stride = 0;
- }
+ // transferReadSupportsMMAMatrixType ensures that either of the map results is
+ // the most minor dimension. Under this constraint, whether the map represents
+ // a transposed view can be inferred from whether the first result is the most
+ // minor memref dimension.
+ bool isTranspose = isFirstResultLastMapDimension(permutationMap);
Value mappingResult = op.getResult();
auto elType = op.getVectorType().getElementType();
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 71a748d33831e..bf858789c7e07 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -1,5 +1,21 @@
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-gpu),canonicalize)" --split-input-file | FileCheck %s
+// -----
+
+// The pass currently only works for 2D vector transfers.
+// CHECK-LABEL: func @no_convert_3d
+// CHECK-NOT: gpu
+func.func @no_convert_3d(%arg0: memref<2x2x2xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<2x2x2xf16>, vector<2x2x2xf16>
+ %B = arith.addf %A, %A : vector<2x2x2xf16>
+ vector.transfer_write %B, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<2x2x2xf16>, memref<2x2x2xf16>
+ return
+}
+
+// -----
+
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
@@ -575,3 +591,74 @@ func.func @matmul_with_strides(%arg0: memref<16x16xf16>, %arg1: memref<16x6x16xf
vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>} : vector<16x16xf16>, memref<16x9x16xf16>
return
}
+
+// -----
+
+// CHECK-LABEL: func @read_transpose_with_strides_3d
+func.func @read_transpose_with_strides_3d(%arg0: memref<5x7x3xf16>, %arg1: memref<2x5x3xf16>, %arg2: memref<3x5xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 21 : index, transpose} : memref<5x7x3xf16> -> !gpu.mma_matrix<3x5xf16, "COp">
+ %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d2, d0)>} : memref<5x7x3xf16>, vector<3x5xf16>
+ // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 3 : index, transpose} : memref<2x5x3xf16> -> !gpu.mma_matrix<3x5xf16, "COp">
+ %B = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} : memref<2x5x3xf16>, vector<3x5xf16>
+ %C = arith.addf %A, %B : vector<3x5xf16>
+ vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<3x5xf16>, memref<3x5xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @read_transpose_with_strides_4d
+func.func @read_transpose_with_strides_4d(%arg0: memref<5x7x11x3xf16>, %arg1: memref<2x5x11x3xf16>, %arg2: memref<3x5xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 231 : index, transpose} : memref<5x7x11x3xf16> -> !gpu.mma_matrix<3x5xf16, "COp">
+ %A = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d0)>} : memref<5x7x11x3xf16>, vector<3x5xf16>
+ // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 33 : index, transpose} : memref<2x5x11x3xf16> -> !gpu.mma_matrix<3x5xf16, "COp">
+ %B = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1)>} : memref<2x5x11x3xf16>, vector<3x5xf16>
+ %C = arith.addf %A, %B : vector<3x5xf16>
+ vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<3x5xf16>, memref<3x5xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @no_convert_read_transpose_not_last_dim
+// CHECK-NOT: gpu
+func.func @no_convert_read_transpose_not_last_dim(%arg0: memref<2x2x2xf16>, %arg1: memref<2x2xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ // Legal map, but does not map the last memref dim so should not be lowered to an MMA load.
+ %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d1, d0)>} : memref<2x2x2xf16>, vector<2x2xf16>
+ %B = arith.addf %A, %A : vector<2x2xf16>
+ vector.transfer_write %B, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<2x2xf16>, memref<2x2xf16>
+ return
+}
+
+// -----
+
+// Transpose write is not supported.
+// CHECK-LABEL: func @no_convert_write_transpose
+// CHECK-NOT: gpu
+func.func @no_convert_write_transpose(%arg0: memref<2x2xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<2x2xf16>, vector<2x2xf16>
+ %B = arith.addf %A, %A : vector<2x2xf16>
+ vector.transfer_write %B, %arg0[%c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<2x2xf16>, memref<2x2xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @read_transpose_with_broadcast_3d
+func.func @read_transpose_with_broadcast_3d(%arg0: memref<2x2x2xf16>, %arg1: memref<2x2xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 0 : index, transpose} : memref<2x2x2xf16> -> !gpu.mma_matrix<2x2xf16, "COp">
+ %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d2, 0)>} : memref<2x2x2xf16>, vector<2x2xf16>
+ %B = arith.addf %A, %A : vector<2x2xf16>
+ vector.transfer_write %B, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<2x2xf16>, memref<2x2xf16>
+ return
+}
More information about the Mlir-commits mailing list