[Mlir-commits] [mlir] [MLIR][Linalg] Specialize linalg.generic to linalg.mmt4d (PR #189719)
Stephen Long
llvmlistbot at llvm.org
Tue Mar 31 10:49:11 PDT 2026
https://github.com/steplong created https://github.com/llvm/llvm-project/pull/189719
Specialize linalg.generic to linalg.mmt4d based on index map
>From 9e227757c454c1523163a2300a05f6245815de3d Mon Sep 17 00:00:00 2001
From: Stephen Long <steplong at quicinc.com>
Date: Tue, 31 Mar 2026 10:33:06 -0700
Subject: [PATCH] [MLIR][Linalg] Specialize linalg.generic to linalg.mmt4d
Specialize linalg.generic to linalg.mmt4d based on index map
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 40 +++++++++
.../test/Dialect/Linalg/specialize-mmt4d.mlir | 86 +++++++++++++++++++
2 files changed, 126 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/specialize-mmt4d.mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 60b18fb2e8d93..49dd1b7d73b51 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -216,6 +216,44 @@ static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
return TypeFn::cast_signed;
}
+static FailureOr<LinalgOp> specializeLinalgMmt4D(RewriterBase &rewriter,
+ GenericOp genericOp,
+ std::optional<TypeFn> castTy,
+ ContractionDimensions &dims) {
+ // Should all be rank 4 and dim 6
+ auto indexingMaps = genericOp.getIndexingMapsArray();
+ if (llvm::any_of(indexingMaps, [](AffineMap m) {
+ return m.getResults().size() != 4 || m.getNumDims() != 6;
+ }))
+ return failure();
+
+ auto aOuter =
+ matchOperandMap(indexingMaps[0], 0, dims.m[0], dims.k[0]);
+ auto aInner =
+ matchOperandMap(indexingMaps[0], 2, dims.m[1], dims.k[1]);
+
+ auto bOuter =
+ matchOperandMap(indexingMaps[1], 0, dims.k[0], dims.n[0]);
+ auto bInner =
+ matchOperandMap(indexingMaps[1], 2, dims.k[1], dims.n[1]);
+
+ auto cOuter =
+ matchOperandMap(indexingMaps[2], 0, dims.m[0], dims.n[0]);
+ auto cInner =
+ matchOperandMap(indexingMaps[2], 2, dims.m[1], dims.n[1]);
+
+ if (llvm::is_contained({aOuter, bOuter, cOuter}, IndexMatchResult::Mismatch))
+ return failure();
+ if (llvm::is_contained({aInner, bInner, cInner}, IndexMatchResult::Mismatch))
+ return failure();
+
+ SmallVector<AffineMap> namedOpMaps =
+ {indexingMaps[0], indexingMaps[1], indexingMaps[2]};
+
+ return replaceWithMatmulVariant<Mmt4DOp>(rewriter, genericOp, castTy,
+ namedOpMaps);
+}
+
// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
GenericOp genericOp,
@@ -277,6 +315,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
if (!succeeded(res))
return failure();
auto dims = *res;
+ if (dims.m.size() == 2 && dims.n.size() == 2 && dims.k.size() == 2)
+ return specializeLinalgMmt4D(rewriter, genericOp, castTy, dims);
if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
return failure();
diff --git a/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir b/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir
new file mode 100644
index 0000000000000..19b1759043434
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: @generic_to_mmt4d
+// CHECK: linalg.mmt4d
+func.func @generic_to_mmt4d(
+ %A : tensor<?x?x?x?xf32>,
+ %B : tensor<?x?x?x?xf32>,
+ %C : tensor<?x?x?x?xf32>
+) -> tensor<?x?x?x?xf32> {
+
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(m, n, k, m0, n0, k0) -> (m, k, m0, k0)>,
+ affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>,
+ affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction",
+ "parallel", "parallel", "reduction"]
+ }
+ ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs(%C : tensor<?x?x?x?xf32>) {
+ ^bb0(%a : f32, %b : f32, %c : f32):
+ %mul = arith.mulf %a, %b : f32
+ %add = arith.addf %c, %mul : f32
+ linalg.yield %add : f32
+ } -> tensor<?x?x?x?xf32>
+
+ return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: @generic_to_mmt4d_transposed_inner
+// CHECK: linalg.mmt4d
+func.func @generic_to_mmt4d_transposed_inner(
+ %A : tensor<?x?x?x?xf32>,
+ %B : tensor<?x?x?x?xf32>,
+ %C : tensor<?x?x?x?xf32>
+) -> tensor<?x?x?x?xf32> {
+
+ %0 = linalg.generic {
+ indexing_maps = [
+ // Inner dims swapped (m0,k0) to (k0,m0)
+ affine_map<(m, n, k, m0, n0, k0) -> (m, k, k0, m0)>,
+ affine_map<(m, n, k, m0, n0, k0) -> (n, k, k0, n0)>,
+ affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction",
+ "parallel", "parallel", "reduction"]
+ }
+ ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs(%C : tensor<?x?x?x?xf32>) {
+ ^bb0(%a : f32, %b : f32, %c : f32):
+ %mul = arith.mulf %a, %b : f32
+ %add = arith.addf %c, %mul : f32
+ linalg.yield %add : f32
+ } -> tensor<?x?x?x?xf32>
+
+ return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: @no_mmt4d_bad_map
+// CHECK-NOT: linalg.mmt4d
+func.func @no_mmt4d_bad_map(
+ %A : tensor<?x?x?x?xf32>,
+ %B : tensor<?x?x?x?xf32>,
+ %C : tensor<?x?x?x?xf32>
+) -> tensor<?x?x?x?xf32> {
+
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(m, n, k, m0, n0, k0) -> (k, m, m0, k0)>, // bad map
+ affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>,
+ affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction",
+ "parallel", "parallel", "reduction"]
+ }
+ ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs(%C : tensor<?x?x?x?xf32>) {
+ ^bb0(%a : f32, %b : f32, %c : f32):
+ %mul = arith.mulf %a, %b : f32
+ %add = arith.addf %c, %mul : f32
+ linalg.yield %add : f32
+ } -> tensor<?x?x?x?xf32>
+
+ return %0 : tensor<?x?x?x?xf32>
+}
More information about the Mlir-commits
mailing list