[Mlir-commits] [mlir] ca80bda - [MLIR][Linalg] Specialize linalg.generic to linalg.mmt4d (#189719)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 10 07:14:07 PDT 2026


Author: Stephen Long
Date: 2026-04-10T10:14:03-04:00
New Revision: ca80bda446b50c5124ab653c480decf9ae8b1288

URL: https://github.com/llvm/llvm-project/commit/ca80bda446b50c5124ab653c480decf9ae8b1288
DIFF: https://github.com/llvm/llvm-project/commit/ca80bda446b50c5124ab653c480decf9ae8b1288.diff

LOG: [MLIR][Linalg] Specialize linalg.generic to linalg.mmt4d (#189719)

Specialize linalg.generic to linalg.mmt4d based on index map

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
    mlir/test/Dialect/Linalg/specialize-generic-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a764d1705e85c..a7cd57cf4ed9e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -307,6 +307,38 @@ static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
   return TypeFn::cast_signed;
 }
 
+static FailureOr<LinalgOp> specializeLinalgMmt4D(RewriterBase &rewriter,
+                                                 GenericOp genericOp,
+                                                 std::optional<TypeFn> castTy,
+                                                 ContractionDimensions &dims) {
+  // Bail out unless every indexing map has exactly 4 results over 6 loop dims.
+  auto indexingMaps = genericOp.getIndexingMapsArray();
+  if (llvm::any_of(indexingMaps, [](AffineMap m) {
+        return m.getResults().size() != 4 || m.getNumDims() != 6;
+      }))
+    return failure();
+
+  auto aOuter = matchOperandMap(indexingMaps[0], 0, dims.m[0], dims.k[0]); // NOTE(review): presumably compares results [pos, pos+1] with the two dims - confirm
+  auto aInner = matchOperandMap(indexingMaps[0], 2, dims.m[1], dims.k[1]);
+
+  auto bOuter = matchOperandMap(indexingMaps[1], 0, dims.k[0], dims.n[0]);
+  auto bInner = matchOperandMap(indexingMaps[1], 2, dims.k[1], dims.n[1]);
+
+  auto cOuter = matchOperandMap(indexingMaps[2], 0, dims.m[0], dims.n[0]);
+  auto cInner = matchOperandMap(indexingMaps[2], 2, dims.m[1], dims.n[1]);
+
+  if (llvm::is_contained({aOuter, bOuter, cOuter}, IndexMatchResult::Mismatch)) // only Mismatch rejects; any other result is accepted
+    return failure();
+  if (llvm::is_contained({aInner, bInner, cInner}, IndexMatchResult::Mismatch)) // same rule for the matches taken at position 2
+    return failure();
+
+  SmallVector<AffineMap> namedOpMaps = {indexingMaps[0], indexingMaps[1], // the generic op's first three maps are passed on unchanged
+                                        indexingMaps[2]};
+
+  return replaceWithMatmulVariant<Mmt4DOp>(rewriter, genericOp, castTy,
+                                           namedOpMaps);
+}
+
 // Converts linalg.generic to named linalg.*matmul* where possible.
 static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                         GenericOp genericOp,
@@ -368,6 +400,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   if (!succeeded(res))
     return failure();
   auto dims = *res;
+  if (dims.m.size() == 2 && dims.n.size() == 2 && dims.k.size() == 2)
+    return specializeLinalgMmt4D(rewriter, genericOp, castTy, dims);
   if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
     return failure();
 

diff  --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 37dec828687bd..1cca2b86ddc25 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -1205,3 +1205,197 @@ func.func @op_batch_matmul_broadcast_b(%A: tensor<2x16x8xf32>, %B: tensor<8xf32>
 
 // CATEGORY-NOT: linalg.generic
 // CATEGORY: linalg.contract
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.mmt4d
+///----------------------------------------------------------------------------------------
+
+#mapA = affine_map<(m, n, k, m0, n0, k0) -> (m, k, m0, k0)>
+#mapB = affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>
+#mapC = affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
+func.func @op_mmt4d(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                    %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#mapA, #mapB, #mapC],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// MMT4D transpose A inner and outer:
+//   A is accessed as (k, m, k0, m0) instead of (m, k, m0, k0)
+#map_tA = affine_map<(m, n, k, m0, n0, k0) -> (k, m, k0, m0)>
+func.func @op_mmt4d_transpose_a(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                                %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#map_tA, #mapB, #mapC],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_a
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// MMT4D transpose B inner and outer:
+//   B is accessed as (k, n, k0, n0) instead of (n, k, n0, k0)
+#map_tB = affine_map<(m, n, k, m0, n0, k0) -> (k, n, k0, n0)>
+func.func @op_mmt4d_transpose_b(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                                %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#mapA, #map_tB, #mapC],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_b
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// MMT4D transpose both A and B inner and outer:
+func.func @op_mmt4d_transpose_a_and_b(
+    %A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+    %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#map_tA, #map_tB, #mapC],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_a_and_b
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// MMT4D transpose C inner and outer:
+//   C is accessed as (n, m, n0, m0) instead of (m, n, m0, n0)
+#map_tC = affine_map<(m, n, k, m0, n0, k0) -> (n, m, n0, m0)>
+func.func @op_mmt4d_transpose_c(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                                %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#mapA, #mapB, #map_tC],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_c
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// MMT4D transpose C inner only:
+//   C is accessed as (m, n, n0, m0) instead of (m, n, m0, n0)
+#map_tC_inner = affine_map<(m, n, k, m0, n0, k0) -> (m, n, n0, m0)>
+func.func @op_mmt4d_transpose_c_inner(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                                      %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#mapA, #mapB, #map_tC_inner],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_c_inner
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// Negative MMT4D: A is accessed as (n, k, n0, k0), which is neither (m, k, m0, k0)
+// nor an outer/inner transpose of it, so the op must not become linalg.mmt4d.
+#mapA_negative = affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>
+func.func @negative_op_mmt4d(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                             %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [#mapA_negative, #mapB, #mapC],
+    iterator_types = ["parallel", "parallel", "reduction",
+                      "parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: negative_op_mmt4d
+
+// NAMED-NOT: linalg.mmt4d
+// NAMED: linalg.generic
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract


        


More information about the Mlir-commits mailing list