[Mlir-commits] [mlir] be9e84e - [mlir] [linalg] fix failure on specializing matmul with permuted loops (#184294)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 6 00:46:36 PST 2026
Author: ziereis
Date: 2026-03-06T09:46:32+01:00
New Revision: be9e84e1a9087941f4da9259fb64bb975029b0cb
URL: https://github.com/llvm/llvm-project/commit/be9e84e1a9087941f4da9259fb64bb975029b0cb
DIFF: https://github.com/llvm/llvm-project/commit/be9e84e1a9087941f4da9259fb64bb975029b0cb.diff
LOG: [mlir] [linalg] fix failure on specializing matmul with permuted loops (#184294)
This patch fixes generic specialization when the loop dimensions of the
generic are permuted w.r.t. the canonical iterator order of the named
ops. Instead of forwarding the indexing maps of the original generic, it
recreates them so that they always follow the canonical order.
For example, a generic that is to be specialized to a matmul could
have `[parallel, reduction, parallel]` loops. Specializing it as is
and just copying the indexing maps, as is currently done, leads to a
verification error, since the dimensions do not match the canonical form
the named matmul op expects.
E.g., the original maps could be:
```
(m, k, n) -> (m,k)
...
```
So the maps have to be recreated as:
```
(m,n,k) -> (m,k)
...
```
Assisted by: Claude Code
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index d74335e3c08c9..b4de2bb1e1169 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -139,7 +139,8 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
// attribute is needed for the named matmul op variant.
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
- std::optional<TypeFn> castTy) {
+ std::optional<TypeFn> castTy,
+ ArrayRef<AffineMap> indexingMaps) {
SmallVector<NamedAttribute> castAttrVec;
// Only explicitly specify the cast attribute for unsigned cast; signed is
// the default for linalg.matmul/linalg.batch_matmul.
@@ -147,14 +148,11 @@ static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
castAttrVec = {rewriter.getNamedAttr(
"cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
- ArrayAttr indexingMaps = op.getIndexingMaps();
-
- LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
+ auto namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
ValueRange{op.getDpsInits()[0]}, castAttrVec);
- // Set the original generic's maps to preserve transposed operand semantics.
- namedOp->setAttr("indexing_maps", indexingMaps);
+ namedOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(indexingMaps));
return namedOp;
}
@@ -299,11 +297,38 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return rewriter.notifyMatchFailure(
genericOp, "contains invalid cast ops for the named matmul op");
- /// Codegen the different matmul variants.
+ // Build indexing maps for the named op in its canonical dimension ordering
+ auto *ctx = genericOp.getContext();
+ unsigned numLoopDims = numOfBatchDims + 3;
+ unsigned mIdx = numOfBatchDims;
+ unsigned nIdx = mIdx + 1;
+ unsigned kIdx = mIdx + 2;
+
+ // TODO: add support for indexing_maps with broadcasts.
+ auto makeMap = [&](IndexMatchResult match, unsigned rowIdx, unsigned colIdx) {
+ SmallVector<unsigned> tensorDims;
+ for (unsigned i = 0; i < numOfBatchDims; ++i)
+ tensorDims.push_back(i);
+ if (match == IndexMatchResult::Transposed)
+ llvm::append_values(tensorDims, colIdx, rowIdx);
+ else
+ llvm::append_values(tensorDims, rowIdx, colIdx);
+ return AffineMap::getMultiDimMapWithTargets(numLoopDims, tensorDims, ctx);
+ };
+
+ auto mapA = makeMap(a, mIdx, kIdx);
+ auto mapB = makeMap(b, kIdx, nIdx);
+ auto mapC = makeMap(c, mIdx, nIdx);
+
+ SmallVector<AffineMap> namedOpMaps = {mapA, mapB, mapC};
+
+ // Codegen the different matmul variants.
if (numOfBatchDims) {
- return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
+ return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy,
+ namedOpMaps);
}
- return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
+ return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy,
+ namedOpMaps);
}
/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 87218844c5c39..5c58a5fedd639 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -530,3 +530,213 @@ func.func @op_matmul_transposed_output(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_TC]]]
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Matmul with non-canonical loop ordering.
+#map_nc_a = affine_map<(m, k, n) -> (m, k)>
+#map_nc_b = affine_map<(m, k, n) -> (k, n)>
+#map_nc_c = affine_map<(m, k, n) -> (m, n)>
+func.func @op_matmul_non_canonical_loops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_nc_a, #map_nc_b, #map_nc_c],
+ iterator_types = ["parallel", "reduction", "parallel"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: op_matmul_non_canonical_loops
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Batch matmul with non-canonical loop ordering.
+#map_bnc_a = affine_map<(batch, m, k, n) -> (batch, m, k)>
+#map_bnc_b = affine_map<(batch, m, k, n) -> (batch, k, n)>
+#map_bnc_c = affine_map<(batch, m, k, n) -> (batch, m, n)>
+func.func @op_batch_matmul_non_canonical_loops(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>,
+ %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_bnc_a, #map_bnc_b, #map_bnc_c],
+ iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+ ins(%A, %B : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x16x16xf32>
+ return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: op_batch_matmul_non_canonical_loops
+// CHECK-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x8x16xf32>, %[[Out:.+]]: tensor<2x16x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+
+// -----
+
+// Matmul with non-canonical loop ordering (d0=m, d1=k, d2=n) and B transposed.
+#map_nc_tb_a = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map_nc_tb_b = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map_nc_tb_c = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @op_matmul_non_canonical_transpose_b(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_nc_tb_a, #map_nc_tb_b, #map_nc_tb_c],
+ iterator_types = ["parallel", "reduction", "parallel"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: op_matmul_non_canonical_transpose_b
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// Batch matmul with non-canonical loop ordering (d0=batch, d1=m, d2=k, d3=n)
+// and B Transposed.
+#map_bnc_tb_a = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map_bnc_tb_b = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map_bnc_tb_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+func.func @op_batch_matmul_non_canonical_transpose_b(%A: tensor<2x16x8xf32>, %B: tensor<2x16x8xf32>,
+ %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_bnc_tb_a, #map_bnc_tb_b, #map_bnc_tb_c],
+ iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+ ins(%A, %B : tensor<2x16x8xf32>, tensor<2x16x8xf32>) outs(%Out : tensor<2x16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x16x16xf32>
+ return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-DAG: #[[$MAP_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP_TB:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MAP_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-LABEL: op_batch_matmul_non_canonical_transpose_b
+// CHECK-SAME: %[[A:.+]]: tensor<2x16x8xf32>, %[[B:.+]]: tensor<2x16x8xf32>, %[[Out:.+]]: tensor<2x16x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_TB]], #[[$MAP_C]]]
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x16x8xf32>)
+// CHECK-SAME: outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+
+// -----
+
+// Matmul with fully permuted loop ordering.
+#map_fs_a = affine_map<(d0, d1, d2) -> (d1, d0)>
+#map_fs_b = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map_fs_c = affine_map<(d0, d1, d2) -> (d1, d2)>
+func.func @op_matmul_fully_shuffled_loops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_fs_a, #map_fs_b, #map_fs_c],
+ iterator_types = ["reduction", "parallel", "parallel"]}
+ ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: op_matmul_fully_shuffled_loops
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+// TODO: this could also be specialized to a named matmul.
+#map_bcast_a = affine_map<(d0, d1, d2) -> (d2)>
+#map_bcast_b = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map_bcast_c = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_broadcast_a(%A: tensor<?xf32>, %B: tensor<?x?xf32>,
+ %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_bcast_a, #map_bcast_b, #map_bcast_c],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: negative_matmul_broadcast_a
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// -----
+
+// TODO: this could also be specialized to a named batch_matmul.
+#map_bbcast_a = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+#map_bbcast_b = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map_bbcast_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_batch_matmul_broadcast_a(%A: tensor<16x8xf32>, %B: tensor<2x8x16xf32>,
+ %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_bbcast_a, #map_bbcast_b, #map_bbcast_c],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xf32>, tensor<2x8x16xf32>) outs(%Out : tensor<2x16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x16x16xf32>
+ return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: negative_batch_matmul_broadcast_a
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.batch_matmul
+
+// -----
+
+// TODO: this could also be specialized to a named batch_matmul.
+#map_bbcast2_a = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map_bbcast2_b = affine_map<(d0, d1, d2, d3) -> (d3)>
+#map_bbcast2_c = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_batch_matmul_broadcast_b(%A: tensor<2x16x8xf32>, %B: tensor<8xf32>,
+ %Out: tensor<2x16x16xf32>) -> tensor<2x16x16xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map_bbcast2_a, #map_bbcast2_b, #map_bbcast2_c],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<2x16x8xf32>, tensor<8xf32>) outs(%Out : tensor<2x16x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.addf %out, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<2x16x16xf32>
+ return %0 : tensor<2x16x16xf32>
+}
+
+// CHECK-LABEL: negative_batch_matmul_broadcast_b
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.batch_matmul
More information about the Mlir-commits
mailing list