[Mlir-commits] [mlir] 49f2896 - [MLIR] Enable scalable vectorization for linalg.batch_matmul (#172333)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 5 01:43:32 PST 2026
Author: Momchil Velikov
Date: 2026-01-05T09:43:28Z
New Revision: 49f28961484580cd4a3f27aeba8bfc2e30a449b1
URL: https://github.com/llvm/llvm-project/commit/49f28961484580cd4a3f27aeba8bfc2e30a449b1
DIFF: https://github.com/llvm/llvm-project/commit/49f28961484580cd4a3f27aeba8bfc2e30a449b1.diff
LOG: [MLIR] Enable scalable vectorization for linalg.batch_matmul (#172333)
Also add a missing testcase for fixed size `linalg.batch_matmul`
vectorization.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2b76c24334c0a..52d651b59bbd0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2640,6 +2640,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+ isa<linalg::BatchMatmulOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
isa<linalg::BatchMmt4DOp>(op) ||
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 170bae6141609..a5d94bc4f581c 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1725,3 +1725,87 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+ linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?x?xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME: %[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?x?x?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[C1]] : memref<?x?x?xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[C2]] : memref<?x?x?xf32>
+// CHECK: %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[C2_2]] : memref<?x?x?xf32>
+// CHECK: %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK: %[[P0:.*]] = ub.poison : f32
+// CHECK: %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
+// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[C0_4]], %[[C0_4]], %[[C0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x8x4xi1> -> vector<4x8x16x4xf32>
+// CHECK: %[[P1:.*]] = ub.poison : f32
+// CHECK: %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x16xi1>
+// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[C0_4]], %[[C0_4]], %[[C0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x4x16xi1> -> vector<4x8x16x4xf32>
+// CHECK: %[[P2:.*]] = ub.poison : f32
+// CHECK: %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x16xi1>
+// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[C0_4]], %[[C0_4]], %[[C0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x16x4xf32>
+// CHECK: %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x16x4xi1>
+// CHECK: %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x16x4xf32> to vector<4x8x16xf32> } : vector<4x8x16x4xi1> -> vector<4x8x16xf32>
+// CHECK: %[[C0_5:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[C0_5]], %[[C0_5]], %[[C0_5]]] {in_bounds = [true, true, true]} : vector<4x8x16xf32>, memref<?x?x?xf32> } : vector<4x8x16xi1>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %matmul vector_sizes [4, 8, 16, 4] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @batch_matmul_scalable(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+ linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?x?xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @batch_matmul_scalable
+// CHECK-SAME: (%[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?x?x?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[C1]] : memref<?x?x?xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[C2]] : memref<?x?x?xf32>
+// CHECK: %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[C2_2]] : memref<?x?x?xf32>
+// CHECK: %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK: %[[P0:.*]] = ub.poison : f32
+// CHECK: %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
+// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[C0_4]], %[[C0_4]], %[[C0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x8x4xi1> -> vector<4x8x[16]x4xf32>
+// CHECK: %[[P1:.*]] = ub.poison : f32
+// CHECK: %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x[16]xi1>
+// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[C0_4]], %[[C0_4]], %[[C0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x4x[16]xi1> -> vector<4x8x[16]x4xf32>
+// CHECK: %[[P2:.*]] = ub.poison : f32
+// CHECK: %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x[16]xi1>
+// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[C0_4]], %[[C0_4]], %[[C0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x[16]xf32> } : vector<4x8x[16]xi1> -> vector<4x8x[16]xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x[16]x4xf32>
+// CHECK: %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x[16]x4xi1>
+// CHECK: %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x[16]x4xf32> to vector<4x8x[16]xf32> } : vector<4x8x[16]x4xi1> -> vector<4x8x[16]xf32>
+// CHECK: %[[C0_5:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[C0_5]], %[[C0_5]], %[[C0_5]]] {in_bounds = [true, true, true]} : vector<4x8x[16]xf32>, memref<?x?x?xf32> } : vector<4x8x[16]xi1>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %matmul vector_sizes [4, 8, [16], 4] : !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits mailing list