[Mlir-commits] [mlir] [MLIR][Linalg] Specialize linalg.generic to linalg.mmt4d (PR #189719)

Stephen Long llvmlistbot at llvm.org
Thu Apr 9 07:11:49 PDT 2026


https://github.com/steplong updated https://github.com/llvm/llvm-project/pull/189719

>From c388d5a942e9597f82e338a5a8095636f97cfb55 Mon Sep 17 00:00:00 2001
From: Stephen Long <steplong at quicinc.com>
Date: Tue, 31 Mar 2026 10:33:06 -0700
Subject: [PATCH 1/8] [MLIR][Linalg] Specialize linalg.generic to linalg.mmt4d

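Specialize linalg.generic to linalg.mmt4d based on its indexing maps. The
outer (results 0-1) and inner (results 2-3) index pairs of each operand are
matched separately against the two m, n and k contraction dimensions.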
---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 40 +++++++++
 .../test/Dialect/Linalg/specialize-mmt4d.mlir | 86 +++++++++++++++++++
 2 files changed, 126 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/specialize-mmt4d.mlir

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a764d1705e85c..292fcc2343b83 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -307,6 +307,44 @@ static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
   return TypeFn::cast_signed;
 }
 
+static FailureOr<LinalgOp> specializeLinalgMmt4D(RewriterBase &rewriter,
+                                                 GenericOp genericOp,
+                                                 std::optional<TypeFn> castTy,
+                                                 ContractionDimensions &dims) {
+  // Every indexing map must have 4 results over a 6-dim iteration space.
+  auto indexingMaps = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(indexingMaps, [](AffineMap m) {
+        return m.getResults().size() != 4 || m.getNumDims() != 6;
+      }))
+    return failure();
+
+  auto aOuter =
+      matchOperandMap(indexingMaps[0], 0, dims.m[0], dims.k[0]);
+  auto aInner =
+      matchOperandMap(indexingMaps[0], 2, dims.m[1], dims.k[1]);
+
+  auto bOuter =
+      matchOperandMap(indexingMaps[1], 0, dims.k[0], dims.n[0]);
+  auto bInner =
+      matchOperandMap(indexingMaps[1], 2, dims.k[1], dims.n[1]);
+
+  auto cOuter =
+      matchOperandMap(indexingMaps[2], 0, dims.m[0], dims.n[0]);
+  auto cInner =
+      matchOperandMap(indexingMaps[2], 2, dims.m[1], dims.n[1]);
+
+  if (llvm::is_contained({aOuter, bOuter, cOuter}, IndexMatchResult::Mismatch))
+    return failure();
+  if (llvm::is_contained({aInner, bInner, cInner}, IndexMatchResult::Mismatch))
+    return failure();
+
+  SmallVector<AffineMap> namedOpMaps =
+      {indexingMaps[0], indexingMaps[1], indexingMaps[2]};
+
+  return replaceWithMatmulVariant<Mmt4DOp>(rewriter, genericOp, castTy,
+                                           namedOpMaps);
+}
+
 // Converts linalg.generic to named linalg.*matmul* where possible.
 static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                         GenericOp genericOp,
@@ -368,6 +406,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   if (!succeeded(res))
     return failure();
   auto dims = *res;
+  if (dims.m.size() == 2 && dims.n.size() == 2 && dims.k.size() == 2)
+    return specializeLinalgMmt4D(rewriter, genericOp, castTy, dims);
   if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
     return failure();
 
diff --git a/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir b/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir
new file mode 100644
index 0000000000000..19b1759043434
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: @generic_to_mmt4d
+// CHECK: linalg.mmt4d
+func.func @generic_to_mmt4d(
+    %A : tensor<?x?x?x?xf32>,
+    %B : tensor<?x?x?x?xf32>,
+    %C : tensor<?x?x?x?xf32>
+) -> tensor<?x?x?x?xf32> {
+
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(m, n, k, m0, n0, k0) -> (m, k, m0, k0)>,
+      affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>,
+      affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
+    ],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]
+  }
+  ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%a : f32, %b : f32, %c : f32):
+    %mul = arith.mulf %a, %b : f32
+    %add = arith.addf %c, %mul : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?x?x?xf32>
+
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: @generic_to_mmt4d_transposed_inner
+// CHECK: linalg.mmt4d
+func.func @generic_to_mmt4d_transposed_inner(
+    %A : tensor<?x?x?x?xf32>,
+    %B : tensor<?x?x?x?xf32>,
+    %C : tensor<?x?x?x?xf32>
+) -> tensor<?x?x?x?xf32> {
+
+  %0 = linalg.generic {
+    indexing_maps = [
+      // Inner dims swapped (m0,k0) to (k0,m0)
+      affine_map<(m, n, k, m0, n0, k0) -> (m, k, k0, m0)>,
+      affine_map<(m, n, k, m0, n0, k0) -> (n, k, k0, n0)>,
+      affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
+    ],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]
+  }
+  ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%a : f32, %b : f32, %c : f32):
+    %mul = arith.mulf %a, %b : f32
+    %add = arith.addf %c, %mul : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?x?x?xf32>
+
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: @no_mmt4d_bad_map
+// CHECK-NOT: linalg.mmt4d
+func.func @no_mmt4d_bad_map(
+    %A : tensor<?x?x?x?xf32>,
+    %B : tensor<?x?x?x?xf32>,
+    %C : tensor<?x?x?x?xf32>
+) -> tensor<?x?x?x?xf32> {
+
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(m, n, k, m0, n0, k0) -> (k, m, m0, k0)>, // bad map
+      affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>,
+      affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
+    ],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]
+  }
+  ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%a : f32, %b : f32, %c : f32):
+    %mul = arith.mulf %a, %b : f32
+    %add = arith.addf %c, %mul : f32
+    linalg.yield %add : f32
+  } -> tensor<?x?x?x?xf32>
+
+  return %0 : tensor<?x?x?x?xf32>
+}

>From fca5f4ae2d41ab0aaed6ede69801aec5c5f1bb09 Mon Sep 17 00:00:00 2001
From: Stephen Long <steplong at quicinc.com>
Date: Tue, 31 Mar 2026 11:03:59 -0700
Subject: [PATCH 2/8] Fix formatting

---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 22 +++++++------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 292fcc2343b83..a7cd57cf4ed9e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -318,28 +318,22 @@ static FailureOr<LinalgOp> specializeLinalgMmt4D(RewriterBase &rewriter,
       }))
     return failure();
 
-  auto aOuter =
-      matchOperandMap(indexingMaps[0], 0, dims.m[0], dims.k[0]);
-  auto aInner =
-      matchOperandMap(indexingMaps[0], 2, dims.m[1], dims.k[1]);
+  auto aOuter = matchOperandMap(indexingMaps[0], 0, dims.m[0], dims.k[0]);
+  auto aInner = matchOperandMap(indexingMaps[0], 2, dims.m[1], dims.k[1]);
 
-  auto bOuter =
-      matchOperandMap(indexingMaps[1], 0, dims.k[0], dims.n[0]);
-  auto bInner =
-      matchOperandMap(indexingMaps[1], 2, dims.k[1], dims.n[1]);
+  auto bOuter = matchOperandMap(indexingMaps[1], 0, dims.k[0], dims.n[0]);
+  auto bInner = matchOperandMap(indexingMaps[1], 2, dims.k[1], dims.n[1]);
 
-  auto cOuter =
-      matchOperandMap(indexingMaps[2], 0, dims.m[0], dims.n[0]);
-  auto cInner =
-      matchOperandMap(indexingMaps[2], 2, dims.m[1], dims.n[1]);
+  auto cOuter = matchOperandMap(indexingMaps[2], 0, dims.m[0], dims.n[0]);
+  auto cInner = matchOperandMap(indexingMaps[2], 2, dims.m[1], dims.n[1]);
 
   if (llvm::is_contained({aOuter, bOuter, cOuter}, IndexMatchResult::Mismatch))
     return failure();
   if (llvm::is_contained({aInner, bInner, cInner}, IndexMatchResult::Mismatch))
     return failure();
 
-  SmallVector<AffineMap> namedOpMaps =
-      {indexingMaps[0], indexingMaps[1], indexingMaps[2]};
+  SmallVector<AffineMap> namedOpMaps = {indexingMaps[0], indexingMaps[1],
+                                        indexingMaps[2]};
 
   return replaceWithMatmulVariant<Mmt4DOp>(rewriter, genericOp, castTy,
                                            namedOpMaps);

>From 98732b56cf8fb64b50bfa11d2b3242b2a9dc075e Mon Sep 17 00:00:00 2001
From: Stephen Long <steplong at quicinc.com>
Date: Tue, 31 Mar 2026 11:17:23 -0700
Subject: [PATCH 3/8] Fix bad map test. Not sure if we should be accepting this

---
 mlir/test/Dialect/Linalg/specialize-mmt4d.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir b/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir
index 19b1759043434..41be0665965c8 100644
--- a/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir
@@ -67,7 +67,7 @@ func.func @no_mmt4d_bad_map(
 
   %0 = linalg.generic {
     indexing_maps = [
-      affine_map<(m, n, k, m0, n0, k0) -> (k, m, m0, k0)>, // bad map
+      affine_map<(m, n, k, m0, n0, k0) -> (k, n, k0, n0)>, // bad map
       affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>,
       affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
     ],

>From e746c1c5ea4224ac0387e0c8e49a5a8e4fc7cd62 Mon Sep 17 00:00:00 2001
From: Stephen Long <steplong at quicinc.com>
Date: Thu, 2 Apr 2026 11:35:34 -0700
Subject: [PATCH 4/8] Fixup test cases (i.e. move to
 specialize-generic-ops.mlir)

---
 .../Linalg/specialize-generic-ops.mlir        | 167 ++++++++++++++++++
 .../test/Dialect/Linalg/specialize-mmt4d.mlir |  86 ---------
 2 files changed, 167 insertions(+), 86 deletions(-)
 delete mode 100644 mlir/test/Dialect/Linalg/specialize-mmt4d.mlir

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 37dec828687bd..f5df59781bf12 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -1205,3 +1205,170 @@ func.func @op_batch_matmul_broadcast_b(%A: tensor<2x16x8xf32>, %B: tensor<8xf32>
 
 // CATEGORY-NOT: linalg.generic
 // CATEGORY: linalg.contract
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.mmt4d
+///----------------------------------------------------------------------------------------
+
+#mapA = affine_map<(m, n, k, m0, n0, k0) -> (m, k, m0, k0)>
+#mapB = affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>
+#mapC = affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
+func.func @op_mmt4d(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                    %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#mapA, #mapB, #mapC],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d
+
+// NOT-NAMED: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// Matmul transpose A inner and outer:
+//   A is accessed as (k, m, k0, m0) instead of (m, k, m0, k0)
+#map_tA = affine_map<(m, n, k, m0, n0, k0) -> (k, m, k0, m0)>
+func.func @op_mmt4d_transpose_a(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                                %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#map_tA, #mapB, #mapC],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_a
+
+// NOT-NAMED: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// Matmul transpose B inner and outer:
+//   B is accessed as (k, n, k0, n0) instead of (n, k, n0, k0)
+#map_tB = affine_map<(m, n, k, m0, n0, k0) -> (k, n, k0, n0)>
+func.func @op_mmt4d_transpose_b(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                                %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#mapA, #map_tB, #mapC],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_b
+
+// NOT-NAMED: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// Matmul transpose both A and B inner and outer:
+func.func @op_mmt4d_transpose_a_and_b(
+    %A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+    %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#map_tA, #map_tB, #mapC],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_a_and_b
+
+// NOT-NAMED: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// Matmul transpose C inner and outer:
+//   C is accessed as (n, m, n0, m0) instead of (m, n, m0, n0)
+#map_tC = affine_map<(m, n, k, m0, n0, k0) -> (n, m, n0, m0)>
+func.func @op_mmt4d_transpose_c(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                                %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#mapA, #mapB, #map_tC],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_c
+
+// NOT-NAMED: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// Matmul transpose C inner only:
+//   C is accessed as (m, n, n0, m0) instead of (m, n, m0, n0)
+#map_tC_inner = affine_map<(m, n, k, m0, n0, k0) -> (m, n, n0, m0)>
+func.func @op_mmt4d_transpose_c_inner(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                                      %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#mapA, #mapB, #map_tC_inner],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_c_inner
+
+// NOT-NAMED: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
diff --git a/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir b/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir
deleted file mode 100644
index 41be0665965c8..0000000000000
--- a/mlir/test/Dialect/Linalg/specialize-mmt4d.mlir
+++ /dev/null
@@ -1,86 +0,0 @@
-// RUN: mlir-opt %s -linalg-specialize-generic-ops | FileCheck %s
-
-// CHECK-LABEL: @generic_to_mmt4d
-// CHECK: linalg.mmt4d
-func.func @generic_to_mmt4d(
-    %A : tensor<?x?x?x?xf32>,
-    %B : tensor<?x?x?x?xf32>,
-    %C : tensor<?x?x?x?xf32>
-) -> tensor<?x?x?x?xf32> {
-
-  %0 = linalg.generic {
-    indexing_maps = [
-      affine_map<(m, n, k, m0, n0, k0) -> (m, k, m0, k0)>,
-      affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>,
-      affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
-    ],
-    iterator_types = ["parallel", "parallel", "reduction",
-                      "parallel", "parallel", "reduction"]
-  }
-  ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
-  outs(%C : tensor<?x?x?x?xf32>) {
-  ^bb0(%a : f32, %b : f32, %c : f32):
-    %mul = arith.mulf %a, %b : f32
-    %add = arith.addf %c, %mul : f32
-    linalg.yield %add : f32
-  } -> tensor<?x?x?x?xf32>
-
-  return %0 : tensor<?x?x?x?xf32>
-}
-
-// CHECK-LABEL: @generic_to_mmt4d_transposed_inner
-// CHECK: linalg.mmt4d
-func.func @generic_to_mmt4d_transposed_inner(
-    %A : tensor<?x?x?x?xf32>,
-    %B : tensor<?x?x?x?xf32>,
-    %C : tensor<?x?x?x?xf32>
-) -> tensor<?x?x?x?xf32> {
-
-  %0 = linalg.generic {
-    indexing_maps = [
-      // Inner dims swapped (m0,k0) to (k0,m0)
-      affine_map<(m, n, k, m0, n0, k0) -> (m, k, k0, m0)>,
-      affine_map<(m, n, k, m0, n0, k0) -> (n, k, k0, n0)>,
-      affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
-    ],
-    iterator_types = ["parallel", "parallel", "reduction",
-                      "parallel", "parallel", "reduction"]
-  }
-  ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
-  outs(%C : tensor<?x?x?x?xf32>) {
-  ^bb0(%a : f32, %b : f32, %c : f32):
-    %mul = arith.mulf %a, %b : f32
-    %add = arith.addf %c, %mul : f32
-    linalg.yield %add : f32
-  } -> tensor<?x?x?x?xf32>
-
-  return %0 : tensor<?x?x?x?xf32>
-}
-
-// CHECK-LABEL: @no_mmt4d_bad_map
-// CHECK-NOT: linalg.mmt4d
-func.func @no_mmt4d_bad_map(
-    %A : tensor<?x?x?x?xf32>,
-    %B : tensor<?x?x?x?xf32>,
-    %C : tensor<?x?x?x?xf32>
-) -> tensor<?x?x?x?xf32> {
-
-  %0 = linalg.generic {
-    indexing_maps = [
-      affine_map<(m, n, k, m0, n0, k0) -> (k, n, k0, n0)>, // bad map
-      affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>,
-      affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
-    ],
-    iterator_types = ["parallel", "parallel", "reduction",
-                      "parallel", "parallel", "reduction"]
-  }
-  ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
-  outs(%C : tensor<?x?x?x?xf32>) {
-  ^bb0(%a : f32, %b : f32, %c : f32):
-    %mul = arith.mulf %a, %b : f32
-    %add = arith.addf %c, %mul : f32
-    linalg.yield %add : f32
-  } -> tensor<?x?x?x?xf32>
-
-  return %0 : tensor<?x?x?x?xf32>
-}

>From b4c2f9d8bc87d4b51e18a61abff774c512012c09 Mon Sep 17 00:00:00 2001
From: Stephen Long <steplong at quicinc.com>
Date: Wed, 8 Apr 2026 06:48:20 -0700
Subject: [PATCH 5/8] Fixup NOT-NAMEDs

---
 mlir/test/Dialect/Linalg/specialize-generic-ops.mlir | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index f5df59781bf12..0259e553c1bf2 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -1233,7 +1233,7 @@ func.func @op_mmt4d(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
 
 // ALL-LABEL: op_mmt4d
 
-// NOT-NAMED: linalg.generic
+// NAMED-NOT: linalg.generic
 // NAMED: linalg.mmt4d
 
 // CATEGORY-NOT: linalg.generic
@@ -1260,7 +1260,7 @@ func.func @op_mmt4d_transpose_a(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>
 
 // ALL-LABEL: op_mmt4d_transpose_a
 
-// NOT-NAMED: linalg.generic
+// NAMED-NOT: linalg.generic
 // NAMED: linalg.mmt4d
 
 // CATEGORY-NOT: linalg.generic
@@ -1287,7 +1287,7 @@ func.func @op_mmt4d_transpose_b(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>
 
 // ALL-LABEL: op_mmt4d_transpose_b
 
-// NOT-NAMED: linalg.generic
+// NAMED-NOT: linalg.generic
 // NAMED: linalg.mmt4d
 
 // CATEGORY-NOT: linalg.generic
@@ -1313,7 +1313,7 @@ func.func @op_mmt4d_transpose_a_and_b(
 
 // ALL-LABEL: op_mmt4d_transpose_a_and_b
 
-// NOT-NAMED: linalg.generic
+// NAMED-NOT: linalg.generic
 // NAMED: linalg.mmt4d
 
 // CATEGORY-NOT: linalg.generic
@@ -1340,7 +1340,7 @@ func.func @op_mmt4d_transpose_c(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>
 
 // ALL-LABEL: op_mmt4d_transpose_c
 
-// NOT-NAMED: linalg.generic
+// NAMED-NOT: linalg.generic
 // NAMED: linalg.mmt4d
 
 // CATEGORY-NOT: linalg.generic
@@ -1367,7 +1367,7 @@ func.func @op_mmt4d_transpose_c_inner(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x
 
 // ALL-LABEL: op_mmt4d_transpose_c_inner
 
-// NOT-NAMED: linalg.generic
+// NAMED-NOT: linalg.generic
 // NAMED: linalg.mmt4d
 
 // CATEGORY-NOT: linalg.generic

>From 535f6e309181706952331bf2f9d07beeed3bf592 Mon Sep 17 00:00:00 2001
From: Stephen Long <steplong at quicinc.com>
Date: Wed, 8 Apr 2026 07:02:42 -0700
Subject: [PATCH 6/8] Add negative test for mmt4d

---
 .../Linalg/specialize-generic-ops.mlir        | 24 +++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 0259e553c1bf2..87520f0e4e556 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -1372,3 +1372,27 @@ func.func @op_mmt4d_transpose_c_inner(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x
 
 // CATEGORY-NOT: linalg.generic
 // CATEGORY: linalg.contract
+
+// Negative MMT4D:
+// A must be accessed as (m, k, m0, k0), with its outer and/or inner index pair optionally transposed.
+#mapA_negative = affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>
+func.func @negative_op_mmt4d(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                             %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#mapA_negative, #mapB, #mapC],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: negative_op_mmt4d
+
+// NAMED-NOT: linalg.mmt4d
+// NAMED: linalg.generic

>From 07047fae191e2632186743ead7837437bc96bd82 Mon Sep 17 00:00:00 2001
From: Stephen Long <steplong at quicinc.com>
Date: Wed, 8 Apr 2026 07:04:23 -0700
Subject: [PATCH 7/8] Fix some comments mentioning matmul instead of mmt4d

---
 mlir/test/Dialect/Linalg/specialize-generic-ops.mlir | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 87520f0e4e556..9b8bcd8361f36 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -1239,7 +1239,7 @@ func.func @op_mmt4d(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
 // CATEGORY-NOT: linalg.generic
 // CATEGORY: linalg.contract
 
-// Matmul transpose A inner and outer:
+// MMT4D transpose A inner and outer:
 //   A is accessed as (k, m, k0, m0) instead of (m, k, m0, k0)
 #map_tA = affine_map<(m, n, k, m0, n0, k0) -> (k, m, k0, m0)>
 func.func @op_mmt4d_transpose_a(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
@@ -1266,7 +1266,7 @@ func.func @op_mmt4d_transpose_a(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>
 // CATEGORY-NOT: linalg.generic
 // CATEGORY: linalg.contract
 
-// Matmul transpose B inner and outer:
+// MMT4D transpose B inner and outer:
 //   B is accessed as (k, n, k0, n0) instead of (n, k, n0, k0)
 #map_tB = affine_map<(m, n, k, m0, n0, k0) -> (k, n, k0, n0)>
 func.func @op_mmt4d_transpose_b(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
@@ -1293,7 +1293,7 @@ func.func @op_mmt4d_transpose_b(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>
 // CATEGORY-NOT: linalg.generic
 // CATEGORY: linalg.contract
 
-// Matmul transpose both A and B inner and outer:
+// MMT4D transpose both A and B inner and outer:
 func.func @op_mmt4d_transpose_a_and_b(
     %A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
     %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
@@ -1319,7 +1319,7 @@ func.func @op_mmt4d_transpose_a_and_b(
 // CATEGORY-NOT: linalg.generic
 // CATEGORY: linalg.contract
 
-// Matmul transpose C inner and outer:
+// MMT4D transpose C inner and outer:
 //   C is accessed as (n, m, n0, m0) instead of (m, n, m0, n0)
 #map_tC = affine_map<(m, n, k, m0, n0, k0) -> (n, m, n0, m0)>
 func.func @op_mmt4d_transpose_c(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
@@ -1346,7 +1346,7 @@ func.func @op_mmt4d_transpose_c(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>
 // CATEGORY-NOT: linalg.generic
 // CATEGORY: linalg.contract
 
-// Matmul transpose C inner only:
+// MMT4D transpose C inner only:
 //   C is accessed as (m, n, n0, m0) instead of (m, n, m0, n0)
 #map_tC_inner = affine_map<(m, n, k, m0, n0, k0) -> (m, n, n0, m0)>
 func.func @op_mmt4d_transpose_c_inner(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,

>From 0c7c435ee544640d195382439c010d5861b34b30 Mon Sep 17 00:00:00 2001
From: Stephen Long <steplong at quicinc.com>
Date: Thu, 9 Apr 2026 07:10:55 -0700
Subject: [PATCH 8/8] Add CATEGORY check for negative mmt4d test

---
 mlir/test/Dialect/Linalg/specialize-generic-ops.mlir | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 9b8bcd8361f36..1cca2b86ddc25 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -1396,3 +1396,6 @@ func.func @negative_op_mmt4d(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
 
 // NAMED-NOT: linalg.mmt4d
 // NAMED: linalg.generic
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract



More information about the Mlir-commits mailing list