[Mlir-commits] [mlir] [MLIR] Enable scalable vectorization for linalg.batch_matmul (PR #172333)
Momchil Velikov
llvmlistbot at llvm.org
Mon Dec 22 03:22:33 PST 2025
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/172333
>From de5f8767d95c03d3eee61dfdb23b5c7740a1b995 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 15 Dec 2025 15:19:33 +0000
Subject: [PATCH 1/2] [MLIR] Enable scalable vectorization for
linalg.batch_matmul
Also add a missing test case for fixed-size `linalg.batch_matmul`
vectorization.
---
.../Linalg/Transforms/Vectorization.cpp | 1 +
.../Linalg/vectorization/linalg-ops.mlir | 84 +++++++++++++++++++
2 files changed, 85 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bb3bccdae0e14..4d7e45aa8036f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2640,6 +2640,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+ isa<linalg::BatchMatmulOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
isa<linalg::BatchMmt4DOp>(op) ||
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 170bae6141609..1f8762bd3b1ef 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1725,3 +1725,87 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+ linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?x?xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME: %[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x?x?xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[c1]] : memref<?x?x?xf32>
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?xf32>
+// CHECK: %[[c2_2:.*]] = arith.constant 2 : index
+// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[c2_2]] : memref<?x?x?xf32>
+// CHECK: %[[c0_4:.*]] = arith.constant 0 : index
+// CHECK: %[[P0:.*]] = ub.poison : f32
+// CHECK: %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
+// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x8x4xi1> -> vector<4x8x16x4xf32>
+// CHECK: %[[P1:.*]] = ub.poison : f32
+// CHECK: %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x16xi1>
+// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x4x16xi1> -> vector<4x8x16x4xf32>
+// CHECK: %[[P2:.*]] = ub.poison : f32
+// CHECK: %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x16xi1>
+// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x16x4xf32>
+// CHECK: %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x16x4xi1>
+// CHECK: %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x16x4xf32> to vector<4x8x16xf32> } : vector<4x8x16x4xi1> -> vector<4x8x16xf32>
+// CHECK: %[[c0_5:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[c0_5]], %[[c0_5]], %[[c0_5]]] {in_bounds = [true, true, true]} : vector<4x8x16xf32>, memref<?x?x?xf32> } : vector<4x8x16xi1>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %matmul vector_sizes [4, 8, 16, 4] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @batch_matmul_scalable(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+ linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?x?xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @batch_matmul_scalable
+// CHECK-SAME: (%[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>) {
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x?x?xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[c1]] : memref<?x?x?xf32>
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?xf32>
+// CHECK: %[[c2_2:.*]] = arith.constant 2 : index
+// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[c2_2]] : memref<?x?x?xf32>
+// CHECK: %[[c0_4:.*]] = arith.constant 0 : index
+// CHECK: %[[P0:.*]] = ub.poison : f32
+// CHECK: %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
+// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x8x4xi1> -> vector<4x8x[16]x4xf32>
+// CHECK: %[[P1:.*]] = ub.poison : f32
+// CHECK: %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x[16]xi1>
+// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x4x[16]xi1> -> vector<4x8x[16]x4xf32>
+// CHECK: %[[P2:.*]] = ub.poison : f32
+// CHECK: %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x[16]xi1>
+// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x[16]xf32> } : vector<4x8x[16]xi1> -> vector<4x8x[16]xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x[16]x4xf32>
+// CHECK: %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x[16]x4xi1>
+// CHECK: %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x[16]x4xf32> to vector<4x8x[16]xf32> } : vector<4x8x[16]x4xi1> -> vector<4x8x[16]xf32>
+// CHECK: %[[c0_5:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[c0_5]], %[[c0_5]], %[[c0_5]]] {in_bounds = [true, true, true]} : vector<4x8x[16]xf32>, memref<?x?x?xf32> } : vector<4x8x[16]xi1>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %matmul vector_sizes [4, 8, [16], 4] : !transform.any_op
+ transform.yield
+ }
+}
>From c86fccc9ab98208c36d69c23327f2d30d4e22e4e Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 22 Dec 2025 11:16:53 +0000
Subject: [PATCH 2/2] [fixup] Formatting/capitalization
---
.../Linalg/vectorization/linalg-ops.mlir | 56 +++++++++----------
1 file changed, 28 insertions(+), 28 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 1f8762bd3b1ef..a5d94bc4f581c 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1736,29 +1736,29 @@ func.func @batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref
// CHECK-LABEL: func.func @batch_matmul(
// CHECK-SAME: %[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>
-// CHECK: %[[c0:.*]] = arith.constant 0 : index
-// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x?x?xf32>
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
-// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[c1]] : memref<?x?x?xf32>
-// CHECK: %[[c2:.*]] = arith.constant 2 : index
-// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?xf32>
-// CHECK: %[[c2_2:.*]] = arith.constant 2 : index
-// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[c2_2]] : memref<?x?x?xf32>
-// CHECK: %[[c0_4:.*]] = arith.constant 0 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?x?x?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[C1]] : memref<?x?x?xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[C2]] : memref<?x?x?xf32>
+// CHECK: %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[C2_2]] : memref<?x?x?xf32>
+// CHECK: %[[C0_4:.*]] = arith.constant 0 : index
// CHECK: %[[P0:.*]] = ub.poison : f32
// CHECK: %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
-// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x8x4xi1> -> vector<4x8x16x4xf32>
+// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[C0_4]], %[[C0_4]], %[[C0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x8x4xi1> -> vector<4x8x16x4xf32>
// CHECK: %[[P1:.*]] = ub.poison : f32
// CHECK: %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x16xi1>
-// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x4x16xi1> -> vector<4x8x16x4xf32>
+// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[C0_4]], %[[C0_4]], %[[C0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x4x16xi1> -> vector<4x8x16x4xf32>
// CHECK: %[[P2:.*]] = ub.poison : f32
// CHECK: %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x16xi1>
-// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[C0_4]], %[[C0_4]], %[[C0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x16x4xf32>
// CHECK: %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x16x4xi1>
// CHECK: %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x16x4xf32> to vector<4x8x16xf32> } : vector<4x8x16x4xi1> -> vector<4x8x16xf32>
-// CHECK: %[[c0_5:.*]] = arith.constant 0 : index
-// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[c0_5]], %[[c0_5]], %[[c0_5]]] {in_bounds = [true, true, true]} : vector<4x8x16xf32>, memref<?x?x?xf32> } : vector<4x8x16xi1>
+// CHECK: %[[C0_5:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[C0_5]], %[[C0_5]], %[[C0_5]]] {in_bounds = [true, true, true]} : vector<4x8x16xf32>, memref<?x?x?xf32> } : vector<4x8x16xi1>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -1778,29 +1778,29 @@ func.func @batch_matmul_scalable(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %
// CHECK-LABEL: func.func @batch_matmul_scalable
// CHECK-SAME: (%[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>) {
-// CHECK: %[[c0:.*]] = arith.constant 0 : index
-// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x?x?xf32>
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
-// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[c1]] : memref<?x?x?xf32>
-// CHECK: %[[c2:.*]] = arith.constant 2 : index
-// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?xf32>
-// CHECK: %[[c2_2:.*]] = arith.constant 2 : index
-// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[c2_2]] : memref<?x?x?xf32>
-// CHECK: %[[c0_4:.*]] = arith.constant 0 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?x?x?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[C1]] : memref<?x?x?xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[C2]] : memref<?x?x?xf32>
+// CHECK: %[[C2_2:.*]] = arith.constant 2 : index
+// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[C2_2]] : memref<?x?x?xf32>
+// CHECK: %[[C0_4:.*]] = arith.constant 0 : index
// CHECK: %[[P0:.*]] = ub.poison : f32
// CHECK: %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
-// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x8x4xi1> -> vector<4x8x[16]x4xf32>
+// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[C0_4]], %[[C0_4]], %[[C0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x8x4xi1> -> vector<4x8x[16]x4xf32>
// CHECK: %[[P1:.*]] = ub.poison : f32
// CHECK: %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x[16]xi1>
-// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x4x[16]xi1> -> vector<4x8x[16]x4xf32>
+// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[C0_4]], %[[C0_4]], %[[C0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x4x[16]xi1> -> vector<4x8x[16]x4xf32>
// CHECK: %[[P2:.*]] = ub.poison : f32
// CHECK: %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x[16]xi1>
-// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x[16]xf32> } : vector<4x8x[16]xi1> -> vector<4x8x[16]xf32>
+// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[C0_4]], %[[C0_4]], %[[C0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x[16]xf32> } : vector<4x8x[16]xi1> -> vector<4x8x[16]xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x[16]x4xf32>
// CHECK: %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x[16]x4xi1>
// CHECK: %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x[16]x4xf32> to vector<4x8x[16]xf32> } : vector<4x8x[16]x4xi1> -> vector<4x8x[16]xf32>
-// CHECK: %[[c0_5:.*]] = arith.constant 0 : index
-// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[c0_5]], %[[c0_5]], %[[c0_5]]] {in_bounds = [true, true, true]} : vector<4x8x[16]xf32>, memref<?x?x?xf32> } : vector<4x8x[16]xi1>
+// CHECK: %[[C0_5:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[C0_5]], %[[C0_5]], %[[C0_5]]] {in_bounds = [true, true, true]} : vector<4x8x[16]xf32>, memref<?x?x?xf32> } : vector<4x8x[16]xi1>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
More information about the Mlir-commits
mailing list