[Mlir-commits] [mlir] ca80bda - [MLIR][Linalg] Specialize linalg.generic to linalg.mmt4d (#189719)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 10 07:14:07 PDT 2026
Author: Stephen Long
Date: 2026-04-10T10:14:03-04:00
New Revision: ca80bda446b50c5124ab653c480decf9ae8b1288
URL: https://github.com/llvm/llvm-project/commit/ca80bda446b50c5124ab653c480decf9ae8b1288
DIFF: https://github.com/llvm/llvm-project/commit/ca80bda446b50c5124ab653c480decf9ae8b1288.diff
LOG: [MLIR][Linalg] Specialize linalg.generic to linalg.mmt4d (#189719)
Specialize linalg.generic to linalg.mmt4d based on its indexing maps
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a764d1705e85c..a7cd57cf4ed9e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -307,6 +307,38 @@ static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
return TypeFn::cast_signed;
}
+// Specializes a 6-loop contraction with two m, two n and two k dims to
+// linalg.mmt4d. Each operand's outer (results 0-1) and inner (results 2-3)
+// dim pairs are matched separately; the maps are forwarded unchanged.
+static FailureOr<LinalgOp>
+specializeLinalgMmt4D(RewriterBase &rewriter, GenericOp genericOp,
+                      std::optional<TypeFn> castTy,
+                      const ContractionDimensions &dims) {
+  // Expect exactly the A, B and C maps, each with 6 dims and 4 results.
+  SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
+  if (maps.size() != 3 || llvm::any_of(maps, [](AffineMap m) {
+        return m.getNumResults() != 4 || m.getNumDims() != 6;
+      }))
+    return failure();
+
+  // True unless results [pos, pos + 1] of `map` mismatch the (d0, d1) pair.
+  auto matches = [](AffineMap map, unsigned pos, unsigned d0, unsigned d1) {
+    return matchOperandMap(map, pos, d0, d1) != IndexMatchResult::Mismatch;
+  };
+  // Level 0 checks the outer dims (m[0], n[0], k[0]) at results 0-1 of every
+  // map; level 1 checks the inner dims (m[1], n[1], k[1]) at results 2-3.
+  for (unsigned level = 0; level < 2; ++level) {
+    unsigned pos = 2 * level;
+    if (!matches(maps[0], pos, dims.m[level], dims.k[level]) ||
+        !matches(maps[1], pos, dims.k[level], dims.n[level]) ||
+        !matches(maps[2], pos, dims.m[level], dims.n[level]))
+      return failure();
+  }
+
+  return replaceWithMatmulVariant<Mmt4DOp>(rewriter, genericOp, castTy,
+                                           maps);
+}
+
// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
GenericOp genericOp,
@@ -368,6 +400,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
if (!succeeded(res))
return failure();
auto dims = *res;
+ if (dims.m.size() == 2 && dims.n.size() == 2 && dims.k.size() == 2)
+ return specializeLinalgMmt4D(rewriter, genericOp, castTy, dims);
if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
return failure();
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 37dec828687bd..1cca2b86ddc25 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -1205,3 +1205,197 @@ func.func @op_batch_matmul_broadcast_b(%A: tensor<2x16x8xf32>, %B: tensor<8xf32>
// CATEGORY-NOT: linalg.generic
// CATEGORY: linalg.contract
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.mmt4d
+///----------------------------------------------------------------------------------------
+
+#mapA = affine_map<(m, n, k, m0, n0, k0) -> (m, k, m0, k0)>
+#mapB = affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>
+#mapC = affine_map<(m, n, k, m0, n0, k0) -> (m, n, m0, n0)>
+func.func @op_mmt4d(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                    %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %res = linalg.generic {
+      indexing_maps = [#mapA, #mapB, #mapC],
+      iterator_types = ["parallel", "parallel", "reduction",
+                        "parallel", "parallel", "reduction"]}
+      ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%lhs: f32, %rhs: f32, %acc: f32):
+    %prod = arith.mulf %lhs, %rhs : f32
+    %sum = arith.addf %acc, %prod : f32
+    linalg.yield %sum : f32
+  } -> tensor<?x?x?x?xf32>
+  return %res : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// MMT4D transpose A inner and outer:
+// A is accessed as (k, m, k0, m0) instead of (m, k, m0, k0)
+#map_tA = affine_map<(m, n, k, m0, n0, k0) -> (k, m, k0, m0)>
+func.func @op_mmt4d_transpose_a(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                                 %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %res = linalg.generic {
+      indexing_maps = [#map_tA, #mapB, #mapC],
+      iterator_types = ["parallel", "parallel", "reduction",
+                        "parallel", "parallel", "reduction"]}
+      ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%lhs: f32, %rhs: f32, %acc: f32):
+    %prod = arith.mulf %lhs, %rhs : f32
+    %sum = arith.addf %acc, %prod : f32
+    linalg.yield %sum : f32
+  } -> tensor<?x?x?x?xf32>
+  return %res : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_a
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// MMT4D transpose B inner and outer:
+// B is accessed as (k, n, k0, n0) instead of (n, k, n0, k0)
+#map_tB = affine_map<(m, n, k, m0, n0, k0) -> (k, n, k0, n0)>
+func.func @op_mmt4d_transpose_b(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                                 %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %res = linalg.generic {
+      indexing_maps = [#mapA, #map_tB, #mapC],
+      iterator_types = ["parallel", "parallel", "reduction",
+                        "parallel", "parallel", "reduction"]}
+      ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%lhs: f32, %rhs: f32, %acc: f32):
+    %prod = arith.mulf %lhs, %rhs : f32
+    %sum = arith.addf %acc, %prod : f32
+    linalg.yield %sum : f32
+  } -> tensor<?x?x?x?xf32>
+  return %res : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_b
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// MMT4D transpose both A and B inner and outer:
+func.func @op_mmt4d_transpose_a_and_b(
+    %A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+    %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %res = linalg.generic {
+      indexing_maps = [#map_tA, #map_tB, #mapC],
+      iterator_types = ["parallel", "parallel", "reduction",
+                        "parallel", "parallel", "reduction"]}
+      ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%lhs: f32, %rhs: f32, %acc: f32):
+    %prod = arith.mulf %lhs, %rhs : f32
+    %sum = arith.addf %acc, %prod : f32
+    linalg.yield %sum : f32
+  } -> tensor<?x?x?x?xf32>
+  return %res : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_a_and_b
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// MMT4D transpose C inner and outer:
+// C is accessed as (n, m, n0, m0) instead of (m, n, m0, n0)
+#map_tC = affine_map<(m, n, k, m0, n0, k0) -> (n, m, n0, m0)>
+func.func @op_mmt4d_transpose_c(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                                 %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %res = linalg.generic {
+      indexing_maps = [#mapA, #mapB, #map_tC],
+      iterator_types = ["parallel", "parallel", "reduction",
+                        "parallel", "parallel", "reduction"]}
+      ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%lhs: f32, %rhs: f32, %acc: f32):
+    %prod = arith.mulf %lhs, %rhs : f32
+    %sum = arith.addf %acc, %prod : f32
+    linalg.yield %sum : f32
+  } -> tensor<?x?x?x?xf32>
+  return %res : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_c
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// MMT4D transpose C inner only:
+// C is accessed as (m, n, n0, m0) instead of (m, n, m0, n0)
+#map_tC_inner = affine_map<(m, n, k, m0, n0, k0) -> (m, n, n0, m0)>
+func.func @op_mmt4d_transpose_c_inner(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                                       %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %res = linalg.generic {
+      indexing_maps = [#mapA, #mapB, #map_tC_inner],
+      iterator_types = ["parallel", "parallel", "reduction",
+                        "parallel", "parallel", "reduction"]}
+      ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%lhs: f32, %rhs: f32, %acc: f32):
+    %prod = arith.mulf %lhs, %rhs : f32
+    %sum = arith.addf %acc, %prod : f32
+    linalg.yield %sum : f32
+  } -> tensor<?x?x?x?xf32>
+  return %res : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: op_mmt4d_transpose_c_inner
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.mmt4d
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// Negative MMT4D: A must be indexed by its m/k dims as (m, k, m0, k0), with the
+// outer and/or inner pair optionally swapped; here it is (n, k, n0, k0) instead.
+#mapA_negative = affine_map<(m, n, k, m0, n0, k0) -> (n, k, n0, k0)>
+func.func @negative_op_mmt4d(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>,
+                              %C: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %res = linalg.generic {
+      indexing_maps = [#mapA_negative, #mapB, #mapC],
+      iterator_types = ["parallel", "parallel", "reduction",
+                        "parallel", "parallel", "reduction"]}
+      ins(%A, %B : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%C : tensor<?x?x?x?xf32>) {
+  ^bb0(%lhs: f32, %rhs: f32, %acc: f32):
+    %prod = arith.mulf %lhs, %rhs : f32
+    %sum = arith.addf %acc, %prod : f32
+    linalg.yield %sum : f32
+  } -> tensor<?x?x?x?xf32>
+  return %res : tensor<?x?x?x?xf32>
+}
+
+// ALL-LABEL: negative_op_mmt4d
+
+// NAMED-NOT: linalg.mmt4d
+// NAMED: linalg.generic
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
More information about the Mlir-commits
mailing list