[Mlir-commits] [mlir] [MLIR][Linalg] Specialize linalg.generic to linalg.mmt4d (PR #189719)

Stephen Long llvmlistbot at llvm.org
Tue Mar 31 10:49:11 PDT 2026


https://github.com/steplong created https://github.com/llvm/llvm-project/pull/189719

Specialize linalg.generic to linalg.mmt4d based on its indexing maps.

>From 9e227757c454c1523163a2300a05f6245815de3d Mon Sep 17 00:00:00 2001
From: Stephen Long <steplong at quicinc.com>
Date: Tue, 31 Mar 2026 10:33:06 -0700
Subject: [PATCH] [MLIR][Linalg] Specialize linalg.generic to linalg.mmt4d

Specialize linalg.generic to linalg.mmt4d based on its indexing maps.
---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 40 +++++++++
 .../test/Dialect/Linalg/specialize-mmt4d.mlir | 86 +++++++++++++++++++
 2 files changed, 126 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/specialize-mmt4d.mlir

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 60b18fb2e8d93..49dd1b7d73b51 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -216,6 +216,44 @@ static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
   return TypeFn::cast_signed;
 }
 
+static FailureOr<LinalgOp>
+specializeLinalgMmt4D(RewriterBase &rewriter, GenericOp genericOp,
+                      std::optional<TypeFn> castTy,
+                      const ContractionDimensions &dims) {
+  // linalg.mmt4d has fixed (implicit) indexing maps:
+  //   C[m, n, m0, n0] += A[m, k, m0, k0] * B[n, k, n0, k0]
+  // so every operand must be rank 4 over exactly 6 loops, and each
+  // (row, col) pair must appear in exactly this order. Accepting a
+  // transposed pair would silently change what the named op computes.
+  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(indexingMaps, [](AffineMap m) {
+        return m.getNumResults() != 4 || m.getNumDims() != 6;
+      }))
+    return failure();
+
+  // Result positions 0-1 hold the outer tile dims and 2-3 the inner ones.
+  // Note that B is indexed (n, k), i.e. "transposed" w.r.t. matmul's B.
+  auto aOuter = matchOperandMap(indexingMaps[0], 0, dims.m[0], dims.k[0]);
+  auto aInner = matchOperandMap(indexingMaps[0], 2, dims.m[1], dims.k[1]);
+  auto bOuter = matchOperandMap(indexingMaps[1], 0, dims.n[0], dims.k[0]);
+  auto bInner = matchOperandMap(indexingMaps[1], 2, dims.n[1], dims.k[1]);
+  auto cOuter = matchOperandMap(indexingMaps[2], 0, dims.m[0], dims.n[0]);
+  auto cInner = matchOperandMap(indexingMaps[2], 2, dims.m[1], dims.n[1]);
+
+  // Only an exact match describes mmt4d: Mmt4DOp cannot absorb a permuted
+  // operand, so Transposed must be rejected just like Mismatch.
+  for (IndexMatchResult r : {aOuter, aInner, bOuter, bInner, cOuter, cInner})
+    if (r != IndexMatchResult::Match)
+      return failure();
+
+  // The maps have been verified to encode mmt4d's semantics, so they can be
+  // forwarded unchanged.
+  SmallVector<AffineMap> namedOpMaps = {indexingMaps[0], indexingMaps[1],
+                                        indexingMaps[2]};
+  return replaceWithMatmulVariant<Mmt4DOp>(rewriter, genericOp, castTy,
+                                           namedOpMaps);
+}
+
 // Converts linalg.generic to named linalg.*matmul* where possible.
 static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                         GenericOp genericOp,
@@ -277,6 +315,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   if (!succeeded(res))
     return failure();
   auto dims = *res;
+  if (dims.m.size() == 2 && dims.n.size() == 2 && dims.k.size() == 2)
+    return specializeLinalgMmt4D(rewriter, genericOp, castTy, dims);
   if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
     return failure();
 
diff --git a/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir b/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir
new file mode 100644
index 0000000000000..19b1759043434
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: func.func @generic_to_mmt4d(
+// CHECK-SAME:    %[[A:[a-z0-9]+]]: tensor<?x?x?x?xf32>, %[[B:[a-z0-9]+]]: tensor<?x?x?x?xf32>, %[[C:[a-z0-9]+]]: tensor<?x?x?x?xf32>)
+// CHECK-NOT:     linalg.generic
+// CHECK:         linalg.mmt4d ins(%[[A]], %[[B]] : {{.*}}) outs(%[[C]] :
+func.func @generic_to_mmt4d(
+    %A : tensor<?x?x?x?xf32>,
+    %B : tensor<?x?x?x?xf32>,
+    %C : tensor<?x?x?x?xf32>
+) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(m, n, k, m0, n0, k0) -> (m, k, m0, k0)>,
+      affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>,
+      affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
+    ],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]
+  }
+  ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%a : f32, %b : f32, %c : f32):
+    %mul = arith.mulf %a, %b : f32
+    %add = arith.addf %c, %mul : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @no_mmt4d_transposed_inner(
+// CHECK-NOT:     linalg.mmt4d
+// CHECK:         linalg.generic
+func.func @no_mmt4d_transposed_inner(
+    %A : tensor<?x?x?x?xf32>,
+    %B : tensor<?x?x?x?xf32>,
+    %C : tensor<?x?x?x?xf32>
+) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+      // Inner tile dims swapped: A reads (k0, m0) and B reads (k0, n0), but
+      // linalg.mmt4d requires (m0, k0) and (n0, k0), so this stays generic.
+      affine_map<(m, n, k, m0, n0, k0) -> (m, k, k0, m0)>,
+      affine_map<(m, n, k, m0, n0, k0) -> (n, k, k0, n0)>,
+      affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
+    ],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]
+  }
+  ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%a : f32, %b : f32, %c : f32):
+    %mul = arith.mulf %a, %b : f32
+    %add = arith.addf %c, %mul : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @no_mmt4d_bad_map(
+// CHECK-NOT:     linalg.mmt4d
+// CHECK:         linalg.generic
+func.func @no_mmt4d_bad_map(
+    %A : tensor<?x?x?x?xf32>,
+    %B : tensor<?x?x?x?xf32>,
+    %C : tensor<?x?x?x?xf32>
+) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+      // A's outer tile dims are (k, m) instead of mmt4d's (m, k).
+      affine_map<(m, n, k, m0, n0, k0) -> (k, m, m0, k0)>,
+      affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>,
+      affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
+    ],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]
+  }
+  ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%a : f32, %b : f32, %c : f32):
+    %mul = arith.mulf %a, %b : f32
+    %add = arith.addf %c, %mul : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}



More information about the Mlir-commits mailing list