[Mlir-commits] [mlir] [MLIR] Enable scalable vectorization for linalg.batch_matmul (PR #172333)
Momchil Velikov
llvmlistbot at llvm.org
Mon Dec 15 08:33:07 PST 2025
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/172333
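This adds `linalg::BatchMatmulOp` to the ops accepted by the scalable-vectorization precondition, so `linalg.batch_matmul` can be vectorized with scalable vector sizes (e.g. `[4, 8, [16], 4]`). Also add a missing test case for fixed-size `linalg.batch_matmul` vectorization.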
From de5f8767d95c03d3eee61dfdb23b5c7740a1b995 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 15 Dec 2025 15:19:33 +0000
Subject: [PATCH] [MLIR] Enable scalable vectorization for linalg.batch_matmul
Also add a missing test case for fixed-size `linalg.batch_matmul`
vectorization.
---
.../Linalg/Transforms/Vectorization.cpp | 1 +
.../Linalg/vectorization/linalg-ops.mlir | 84 +++++++++++++++++++
2 files changed, 85 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bb3bccdae0e14..4d7e45aa8036f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2640,6 +2640,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+ isa<linalg::BatchMatmulOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
isa<linalg::BatchMmt4DOp>(op) ||
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 170bae6141609..1f8762bd3b1ef 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1725,3 +1725,87 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+ linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?x?xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME: %[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x?x?xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[c1]] : memref<?x?x?xf32>
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?xf32>
+// CHECK: %[[c2_2:.*]] = arith.constant 2 : index
+// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[c2_2]] : memref<?x?x?xf32>
+// CHECK: %[[c0_4:.*]] = arith.constant 0 : index
+// CHECK: %[[P0:.*]] = ub.poison : f32
+// CHECK: %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
+// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x8x4xi1> -> vector<4x8x16x4xf32>
+// CHECK: %[[P1:.*]] = ub.poison : f32
+// CHECK: %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x16xi1>
+// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x4x16xi1> -> vector<4x8x16x4xf32>
+// CHECK: %[[P2:.*]] = ub.poison : f32
+// CHECK: %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x16xi1>
+// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x16x4xf32>
+// CHECK: %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x16x4xi1>
+// CHECK: %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x16x4xf32> to vector<4x8x16xf32> } : vector<4x8x16x4xi1> -> vector<4x8x16xf32>
+// CHECK: %[[c0_5:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[c0_5]], %[[c0_5]], %[[c0_5]]] {in_bounds = [true, true, true]} : vector<4x8x16xf32>, memref<?x?x?xf32> } : vector<4x8x16xi1>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %matmul vector_sizes [4, 8, 16, 4] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @batch_matmul_scalable(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+ linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?x?xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @batch_matmul_scalable
+// CHECK-SAME: (%[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>) {
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x?x?xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[c1]] : memref<?x?x?xf32>
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?xf32>
+// CHECK: %[[c2_2:.*]] = arith.constant 2 : index
+// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[c2_2]] : memref<?x?x?xf32>
+// CHECK: %[[c0_4:.*]] = arith.constant 0 : index
+// CHECK: %[[P0:.*]] = ub.poison : f32
+// CHECK: %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
+// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x8x4xi1> -> vector<4x8x[16]x4xf32>
+// CHECK: %[[P1:.*]] = ub.poison : f32
+// CHECK: %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x[16]xi1>
+// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x4x[16]xi1> -> vector<4x8x[16]x4xf32>
+// CHECK: %[[P2:.*]] = ub.poison : f32
+// CHECK: %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x[16]xi1>
+// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x[16]xf32> } : vector<4x8x[16]xi1> -> vector<4x8x[16]xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x[16]x4xf32>
+// CHECK: %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x[16]x4xi1>
+// CHECK: %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x[16]x4xf32> to vector<4x8x[16]xf32> } : vector<4x8x[16]x4xi1> -> vector<4x8x[16]xf32>
+// CHECK: %[[c0_5:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[c0_5]], %[[c0_5]], %[[c0_5]]] {in_bounds = [true, true, true]} : vector<4x8x[16]xf32>, memref<?x?x?xf32> } : vector<4x8x[16]xi1>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %matmul vector_sizes [4, 8, [16], 4] : !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits mailing list