[Mlir-commits] [mlir] 71a37dd - [mlir][vector-to-gpu]: Extend MMA Lowerings (#176785)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 3 07:57:15 PST 2026
Author: Jack Frankland
Date: 2026-02-03T15:52:41Z
New Revision: 71a37dd4b9b46cf7d90657c1f1006744a20e18cb
URL: https://github.com/llvm/llvm-project/commit/71a37dd4b9b46cf7d90657c1f1006744a20e18cb
DIFF: https://github.com/llvm/llvm-project/commit/71a37dd4b9b46cf7d90657c1f1006744a20e18cb.diff
LOG: [mlir][vector-to-gpu]: Extend MMA Lowerings (#176785)
Add support for lowering `vector.transfer_read` and `vector.transfer_write`
ops whose permutation maps are not minor identities to
`gpu.subgroup_mma_load_matrix` and `gpu.subgroup_mma_store_matrix`.
If the permutation map is a "strided minor identity", that is, it skips
some intermediate dimensions, e.g. (d0, d1, d2) -> (d0, d2), then we can
express this stride in the `leadDimension` attribute of
`gpu.subgroup_mma_load_matrix` and stride over the missing intermediate
dimensions when we load.
Signed-off-by: Jack Frankland <jack.frankland at arm.com>
Added:
Modified:
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 98434357f826f..335786f554c02 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -116,12 +116,34 @@ static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
}
-// Return the stide for the second-to-last dimension of |type| if it is a memref
+// Return true if the given map represents a minor identity map with "strided"
+// dimensions i.e. (d0, d1, ..., dn-1) -> (da, dn-1) where 0 <= a <= n-1.
+//
+// This currently doesn't support permuted or broadcast dimensions.
+static bool isStridedMinorIdentity(AffineMap permutationMap) {
+ if (permutationMap.getNumResults() != 2)
+ return false;
+
+ AffineExpr innerResult = permutationMap.getResult(1);
+ const unsigned nDim = permutationMap.getNumDims();
+ if (innerResult != getAffineDimExpr(nDim - 1, permutationMap.getContext()))
+ return false;
+
+ return isa<AffineDimExpr>(permutationMap.getResult(0));
+}
+
+// Return the stride for the "row" dimension of |type| if it is a memref
// and has a constant stride.
-static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
+static std::optional<int64_t>
+getStaticallyKnownRowStride(ShapedType type, AffineMap permutationMap) {
auto memrefType = dyn_cast<MemRefType>(type);
if (!memrefType)
- return false;
+ return std::nullopt;
+
+ // We only support permutation maps with two results i.e. loading a 2D matrix.
+ if (2 != permutationMap.getNumResults())
+ return std::nullopt;
+
// If the memref is 0 or 1D the horizontal stride is 0.
if (memrefType.getRank() < 2)
return 0;
@@ -130,7 +152,21 @@ static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
strides.back() != 1)
return std::nullopt;
- int64_t stride = strides[strides.size() - 2];
+
+ int stridePostion = strides.size() - 2;
+ // We need to be careful if we have a permutation map i.e. (d0, d1) -> (d1,
+ // d0), in this case we want to get the stride of the rows i.e. d0 but the
+ // logic below will extract the column stride d1 because the dims have been
+ // swapped.
+ if (!permutationMap.isPermutation()) {
+ // It's possible we have a constant 0 expression here (the permutation map
+ // must be an affine projection). Projected dimensions are already handled
+ // by the caller.
+ if (auto affineDimExpr =
+ dyn_cast<AffineDimExpr>(permutationMap.getResult(0)))
+ stridePostion = affineDimExpr.getPosition();
+ }
+ const int64_t stride = strides[stridePostion];
if (stride == ShapedType::kDynamic)
return std::nullopt;
return stride;
@@ -141,7 +177,9 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
readOp.getVectorType().getRank() != 2)
return false;
- if (!getStaticallyKnownRowStride(readOp.getShapedType()))
+
+ AffineMap map = readOp.getPermutationMap();
+ if (!getStaticallyKnownRowStride(readOp.getShapedType(), map))
return false;
// Only allow integer types if the signedness can be inferred.
@@ -150,13 +188,12 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
!isa<arith::ExtUIOp>(*readOp->user_begin())))
return false;
- AffineMap map = readOp.getPermutationMap();
MLIRContext *ctx = readOp.getContext();
AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
AffineExpr zero = getAffineConstantExpr(0, ctx);
auto broadcastInnerDim =
AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
- return map.isMinorIdentity() || map == broadcastInnerDim ||
+ return isStridedMinorIdentity(map) || map == broadcastInnerDim ||
isTransposeMatrixLoadMap(map);
}
@@ -170,12 +207,12 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
writeOp.getVectorType().getRank() != 2)
return false;
- if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
+
+ AffineMap map = writeOp.getPermutationMap();
+ if (!getStaticallyKnownRowStride(writeOp.getShapedType(), map))
return false;
// TODO: Support transpose once it is added to GPU dialect ops.
- if (!writeOp.getPermutationMap().isMinorIdentity())
- return false;
- return true;
+ return isStridedMinorIdentity(map);
}
/// Return true if the constant is a splat to a 2D vector so that it can be
@@ -547,14 +584,14 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
assert(transferReadSupportsMMAMatrixType(op) &&
"expected convertible operation");
+ AffineMap map = op.getPermutationMap();
std::optional<int64_t> stride =
- getStaticallyKnownRowStride(op.getShapedType());
+ getStaticallyKnownRowStride(op.getShapedType(), map);
if (!stride.has_value()) {
LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
- AffineMap map = op.getPermutationMap();
bool isTranspose = isTransposeMatrixLoadMap(map);
// Handle broadcast by setting the stride to 0.
@@ -597,7 +634,7 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
assert(transferWriteSupportsMMAMatrixType(op));
std::optional<int64_t> stride =
- getStaticallyKnownRowStride(op.getShapedType());
+ getStaticallyKnownRowStride(op.getShapedType(), op.getPermutationMap());
if (!stride.has_value()) {
LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index ef72901750479..71a748d33831e 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -555,3 +555,23 @@ func.func @addf(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memre
vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
return
}
+
+// -----
+
+// CHECK-LABEL: func @matmul_with_strides
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 96 : index} : memref<16x6x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 144 : index} : memref<16x9x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 144 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x9x16xf16>
+func.func @matmul_with_strides(%arg0: memref<16x16xf16>, %arg1: memref<16x6x16xf16>, %arg2: memref<16x9x16xf16>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>, in_bounds = [true, true]} : memref<16x6x16xf16>, vector<16x16xf16>
+ %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>} : memref<16x9x16xf16>, vector<16x16xf16>
+ %D = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+ vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>} : vector<16x16xf16>, memref<16x9x16xf16>
+ return
+}
More information about the Mlir-commits
mailing list