[Mlir-commits] [mlir] Support non minor id maps (PR #178992)

Michael Platings llvmlistbot at llvm.org
Tue Feb 3 08:06:14 PST 2026


https://github.com/mplatings updated https://github.com/llvm/llvm-project/pull/178992

>From cbd01c755b7b73c97af90661fac668748280865f Mon Sep 17 00:00:00 2001
From: Michael Platings <michael.platings at arm.com>
Date: Fri, 30 Jan 2026 17:47:22 +0000
Subject: [PATCH] [mlir][vector-to-gpu]: Lower transposed strided transfer_read

Add support for lowering vector.transfer_read to
gpu.subgroup_mma_load_matrix when the permutation_map is a transpose
that selects non-minor dimensions, e.g. (d0, d1, d2) -> (d2, d0).
---
 .../Conversion/VectorToGPU/VectorToGPU.cpp    | 142 +++++++++---------
 .../VectorToGPU/vector-to-mma-ops.mlir        |  87 +++++++++++
 2 files changed, 155 insertions(+), 74 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 335786f554c02..65433ae6eb1c6 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -95,78 +95,71 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
   return true;
 }
 
-// Return true if the given map represents a transposed matrix load,
-// i.e. (d0, d1, ...) -> (dn-1, dn-2).
-static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
+// Return true if the permutation map's first result is its last (most minor)
+// dimension; false if the map has no dimensions or no results.
+//
+// When the caller only accepts maps in which the most minor dimension appears
+// as exactly one of the results, this suffices to classify the map as a
+// transpose.
+static bool isFirstResultLastMapDimension(AffineMap permutationMap) {
   MLIRContext *ctx = permutationMap.getContext();
-  // Local OpBuilder is fine here, we just build attributes.
-  OpBuilder b(ctx);
-  auto nDim = permutationMap.getNumDims();
-  AffineExpr zero = b.getAffineConstantExpr(0);
-  if (nDim < 2) {
-    // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
-    AffineExpr dim0 = b.getAffineDimExpr(0);
-    return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
-  }
-
-  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
-  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
-  // Support both transposed and transposed+broadcasted cases.
-  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
-         permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
+  unsigned nDim = permutationMap.getNumDims();
+  return nDim && permutationMap.getNumResults() &&
+         permutationMap.getResult(0) == getAffineDimExpr(nDim - 1, ctx);
 }
 
-// Return true if the given map represents a minor identity map with "strided"
-// dimensions i.e. (d0, d1, ..., dn-1) -> (da, dn-1) where 0 <= a <= n-1.
+// Return the `leadDimension` (row stride) implied by |permutationMap| for
+// |type|, if |type| is a memref with a statically-known layout.
 //
-// This currently doesn't support permuted or broadcast dimensions.
-static bool isStridedMinorIdentity(AffineMap permutationMap) {
-  if (permutationMap.getNumResults() != 2)
-    return false;
-
-  AffineExpr innerResult = permutationMap.getResult(1);
-  const unsigned nDim = permutationMap.getNumDims();
-  if (innerResult != getAffineDimExpr(nDim - 1, permutationMap.getContext()))
-    return false;
-
-  return isa<AffineDimExpr>(permutationMap.getResult(0));
-}
-
-// Return the stride for the "row" dimension of |type| if it is a memref
-// and has a constant stride.
+// The `leadDimension` is the stride (in elements) between consecutive rows in
+// the 2D view described by |permutationMap|, i.e. the memref stride of the
+// most major dimension the map selects. Supported maps (a subset of those
+// accepted by vector.transfer_read) have exactly 2 results, each either an
+// affine dimension or the constant 0 (broadcast).
+//
+// Constraints:
+// - Requires the most minor memref stride to be 1 (memrefs of rank < 2 always
+//   yield 0).
+//
+// Broadcast: if either result is constant 0, the implied `leadDimension` is 0.
 static std::optional<int64_t>
 getStaticallyKnownRowStride(ShapedType type, AffineMap permutationMap) {
   auto memrefType = dyn_cast<MemRefType>(type);
   if (!memrefType)
     return std::nullopt;
-
-  // We only support permutation maps with two results i.e. loading a 2D matrix.
-  if (2 != permutationMap.getNumResults())
-    return std::nullopt;
-
   // If the memref is 0 or 1D the horizontal stride is 0.
   if (memrefType.getRank() < 2)
     return 0;
   int64_t offset = 0;
-  SmallVector<int64_t, 2> strides;
+  SmallVector<int64_t> strides;
   if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
       strides.back() != 1)
     return std::nullopt;
 
-  int stridePostion = strides.size() - 2;
-  // We need to be careful if we have a permutation map i.e. (d0, d1) -> (d1,
-  // d0), in this case we want to get the stride of the rows i.e. d0 but the
-  // logic below will extract the column stride d1 because the dims have been
-  // swapped.
-  if (!permutationMap.isPermutation()) {
-    // It's possible we have a constant 0 expression here (the permutation map
-    // must be an affine projection). Projected dimensions are already handled
-    // by the caller.
-    if (auto affineDimExpr =
-            dyn_cast<AffineDimExpr>(permutationMap.getResult(0)))
-      stridePostion = affineDimExpr.getPosition();
+  if (permutationMap.getNumResults() != 2)
+    return std::nullopt;
+
+  unsigned strideIndex = strides.size();
+
+  for (AffineExpr result : permutationMap.getResults()) {
+    if (auto dim = dyn_cast<AffineDimExpr>(result)) {
+      strideIndex = std::min(strideIndex, dim.getPosition());
+      continue;
+    }
+    auto cst = dyn_cast<AffineConstantExpr>(result);
+    if (!cst || cst.getValue() != 0)
+      return std::nullopt;
+    // A broadcast result forces row stride to 0.
+    return 0;
   }
-  const int64_t stride = strides[stridePostion];
+
+  // Structural validity check: ensure that the map selects at least one
+  // dimension more major than the most minor dimension. This also excludes
+  // degenerate cases where both results map to the most minor dimension.
+  if (strideIndex + 1 >= strides.size())
+    return std::nullopt;
+
+  int64_t stride = strides[strideIndex];
   if (stride == ShapedType::kDynamic)
     return std::nullopt;
   return stride;
@@ -178,8 +171,8 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
       readOp.getVectorType().getRank() != 2)
     return false;
 
-  AffineMap map = readOp.getPermutationMap();
-  if (!getStaticallyKnownRowStride(readOp.getShapedType(), map))
+  AffineMap permutationMap = readOp.getPermutationMap();
+  if (!getStaticallyKnownRowStride(readOp.getShapedType(), permutationMap))
     return false;
 
   // Only allow integer types if the signedness can be inferred.
@@ -189,12 +182,9 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
       return false;
 
   MLIRContext *ctx = readOp.getContext();
-  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
-  AffineExpr zero = getAffineConstantExpr(0, ctx);
-  auto broadcastInnerDim =
-      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
-  return isStridedMinorIdentity(map) || map == broadcastInnerDim ||
-         isTransposeMatrixLoadMap(map);
+  AffineExpr innerDim = getAffineDimExpr(permutationMap.getNumDims() - 1, ctx);
+  return permutationMap.getResult(0) == innerDim ||
+         permutationMap.getResult(1) == innerDim;
 }
 
 // Return true if the transfer op can be converted to a MMA matrix store.
@@ -208,11 +198,17 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
       writeOp.getVectorType().getRank() != 2)
     return false;
 
-  AffineMap map = writeOp.getPermutationMap();
-  if (!getStaticallyKnownRowStride(writeOp.getShapedType(), map))
+  AffineMap permutationMap = writeOp.getPermutationMap();
+  std::optional<int64_t> stride =
+      getStaticallyKnownRowStride(writeOp.getShapedType(), permutationMap);
+  // Stride of zero means broadcast which is not permitted for writes.
+  if (!stride.has_value() || stride.value() == 0)
     return false;
+
+  MLIRContext *ctx = writeOp.getContext();
+  AffineExpr innerDim = getAffineDimExpr(permutationMap.getNumDims() - 1, ctx);
   // TODO: Support transpose once it is added to GPU dialect ops.
-  return isStridedMinorIdentity(map);
+  return permutationMap.getResult(1) == innerDim;
 }
 
 /// Return true if the constant is a splat to a 2D vector so that it can be
@@ -584,21 +580,19 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
   assert(transferReadSupportsMMAMatrixType(op) &&
          "expected convertible operation");
 
-  AffineMap map = op.getPermutationMap();
+  AffineMap permutationMap = op.getPermutationMap();
   std::optional<int64_t> stride =
-      getStaticallyKnownRowStride(op.getShapedType(), map);
+      getStaticallyKnownRowStride(op.getShapedType(), permutationMap);
   if (!stride.has_value()) {
     LDBG() << "no stride";
     return rewriter.notifyMatchFailure(op, "no stride");
   }
 
-  bool isTranspose = isTransposeMatrixLoadMap(map);
-
-  // Handle broadcast by setting the stride to 0.
-  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
-    assert(cstExpr.getValue() == 0);
-    stride = 0;
-  }
+  // transferReadSupportsMMAMatrixType ensures that either of the map results is
+  // the most minor dimension. Under this constraint, whether the map represents
+  // a transposed view can be inferred from whether the first result is the most
+  // minor memref dimension.
+  bool isTranspose = isFirstResultLastMapDimension(permutationMap);
 
   Value mappingResult = op.getResult();
   auto elType = op.getVectorType().getElementType();
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 71a748d33831e..bf858789c7e07 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -1,5 +1,21 @@
 // RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-gpu),canonicalize)" --split-input-file | FileCheck %s
 
+// -----
+
+// The pass currently only works for 2D vector transfers.
+// CHECK-LABEL: func @no_convert_3d
+// CHECK-NOT: gpu
+func.func @no_convert_3d(%arg0: memref<2x2x2xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<2x2x2xf16>, vector<2x2x2xf16>
+  %B = arith.addf %A, %A : vector<2x2x2xf16>
+  vector.transfer_write %B, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<2x2x2xf16>, memref<2x2x2xf16>
+  return
+}
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
@@ -575,3 +591,74 @@ func.func @matmul_with_strides(%arg0: memref<16x16xf16>, %arg1: memref<16x6x16xf
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>} : vector<16x16xf16>, memref<16x9x16xf16>
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @read_transpose_with_strides_3d
+func.func @read_transpose_with_strides_3d(%arg0: memref<5x7x3xf16>, %arg1: memref<2x5x3xf16>, %arg2: memref<3x5xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 21 : index, transpose} : memref<5x7x3xf16> -> !gpu.mma_matrix<3x5xf16, "COp">
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d2, d0)>} : memref<5x7x3xf16>, vector<3x5xf16>
+  // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 3 : index, transpose} : memref<2x5x3xf16> -> !gpu.mma_matrix<3x5xf16, "COp">
+  %B = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} : memref<2x5x3xf16>, vector<3x5xf16>
+  %C = arith.addf %A, %B : vector<3x5xf16>
+  vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<3x5xf16>, memref<3x5xf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @read_transpose_with_strides_4d
+func.func @read_transpose_with_strides_4d(%arg0: memref<5x7x11x3xf16>, %arg1: memref<2x5x11x3xf16>, %arg2: memref<3x5xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 231 : index, transpose} : memref<5x7x11x3xf16> -> !gpu.mma_matrix<3x5xf16, "COp">
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d0)>} : memref<5x7x11x3xf16>, vector<3x5xf16>
+  // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 33 : index, transpose} : memref<2x5x11x3xf16> -> !gpu.mma_matrix<3x5xf16, "COp">
+  %B = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1)>} : memref<2x5x11x3xf16>, vector<3x5xf16>
+  %C = arith.addf %A, %B : vector<3x5xf16>
+  vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<3x5xf16>, memref<3x5xf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @no_convert_read_transpose_not_last_dim
+// CHECK-NOT: gpu
+func.func @no_convert_read_transpose_not_last_dim(%arg0: memref<2x2x2xf16>, %arg1: memref<2x2xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  // Legal map, but it does not select the last memref dim, so it must not be lowered to an MMA load.
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d1, d0)>} : memref<2x2x2xf16>, vector<2x2xf16>
+  %B = arith.addf %A, %A : vector<2x2xf16>
+  vector.transfer_write %B, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<2x2xf16>, memref<2x2xf16>
+  return
+}
+
+// -----
+
+// Transpose write is not supported.
+// CHECK-LABEL: func @no_convert_write_transpose
+// CHECK-NOT: gpu
+func.func @no_convert_write_transpose(%arg0: memref<2x2xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<2x2xf16>, vector<2x2xf16>
+  %B = arith.addf %A, %A : vector<2x2xf16>
+  vector.transfer_write %B, %arg0[%c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<2x2xf16>, memref<2x2xf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @read_transpose_with_broadcast_3d
+func.func @read_transpose_with_broadcast_3d(%arg0: memref<2x2x2xf16>, %arg1: memref<2x2xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 0 : index, transpose} : memref<2x2x2xf16> -> !gpu.mma_matrix<2x2xf16, "COp">
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d2, 0)>} : memref<2x2x2xf16>, vector<2x2xf16>
+  %B = arith.addf %A, %A : vector<2x2xf16>
+  vector.transfer_write %B, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<2x2xf16>, memref<2x2xf16>
+  return
+}



More information about the Mlir-commits mailing list