[Mlir-commits] [mlir] 8de85e7 - [mlir][linalg] Add support for scalable vectorization of `linalg.batch_mmt4d` (#152984)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 14 02:47:54 PDT 2025


Author: Ege Beysel
Date: 2025-08-14T11:47:51+02:00
New Revision: 8de85e753fc995d1a8b980c35dc9413ae51d0a1a

URL: https://github.com/llvm/llvm-project/commit/8de85e753fc995d1a8b980c35dc9413ae51d0a1a
DIFF: https://github.com/llvm/llvm-project/commit/8de85e753fc995d1a8b980c35dc9413ae51d0a1a.diff

LOG: [mlir][linalg] Add support for scalable vectorization of `linalg.batch_mmt4d` (#152984)

This PR builds upon #146531 (scalable vectorization of `linalg.mmt4d`) and
enables scalable vectorization for `linalg.batch_mmt4d` as well.

---------

Signed-off-by: Ege Beysel <beyselege at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 80fbe3cb9ff17..406f05c1b08ef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2609,6 +2609,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                  isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 isa<linalg::BatchMmt4DOp>(op) ||
                  hasReductionIterator(linalgOp));
 }
 

diff  --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 095810fe0451e..01eb210a8ff5f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -880,22 +880,22 @@ func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>,
 // CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>,
 // CHECK-SAME:      %[[B:.*]]: memref<16x16x?x1xf32>,
 // CHECK-SAME:      %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
-// CHECK:           %[[VAL_0:.*]] = arith.constant 16 : index
-// CHECK:           %[[VAL_1:.*]] = arith.constant 16 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK:           %[[C16_M:.*]] = arith.constant 16 : index
+// CHECK:           %[[C16_N:.*]] = arith.constant 16 : index
+// CHECK:           %[[C16_K:.*]] = arith.constant 16 : index
 // CHECK:           %[[C8:.*]] = arith.constant 8 : index
 // CHECK:           %[[C2:.*]] = arith.constant 2 : index
 // CHECK:           %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
-// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
-// CHECK:           %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
+// CHECK:           %[[MASK_1:.*]] = vector.create_mask %[[C16_N]], %[[C16_K]], %[[DIM_2]], %[[C1]] : vector<16x16x[4]x1xi1>
 // CHECK:           %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
-// CHECK:           %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
-// CHECK:           %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
-// CHECK:           %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
-// CHECK:           %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
-// CHECK:           %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
-// CHECK:           vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
+// CHECK:           %[[MASK_2:.*]] = vector.create_mask %[[C16_M]], %[[C16_N]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
+// CHECK:           %[[VEC_C:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_3:.*]] = vector.create_mask %[[C16_M]], %[[C16_N]], %[[C16_K]], %[[C8]], %[[DIM_2]], %[[C1]] : vector<16x16x16x8x[4]x1xi1>
+// CHECK:           %[[RED:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
+// CHECK:           vector.mask %[[MASK_2]] { vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
 
 
 module attributes {transform.with_named_sequence} {
@@ -920,10 +920,10 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1
 // CHECK-NOT:       mask
 // CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
 // CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
-// CHECK:           %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
-// CHECK:           %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
-// CHECK:           %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
-// CHECK:           vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
+// CHECK:           vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -933,6 +933,100 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.batch_mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @batch_mmt4d(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x8x1xf32>, %C_in: memref<2x16x16x8x8xf32>) {
+  linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x8x1xf32>)
+               outs(%C_in: memref<2x16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @batch_mmt4d(
+// CHECK-SAME:      %[[A:.*]]: memref<2x16x16x8x1xf32>, %[[B:.*]]: memref<2x16x16x8x1xf32>, %[[C:.*]]: memref<2x16x16x8x8xf32>) {
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<2x16x16x8x8xf32>, vector<2x16x16x8x8xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x8x1xf32>
+// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x8x1xf32> to vector<2x16x16x8x8xf32>
+// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<2x16x16x8x8xf32>, memref<2x16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %batch_mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_mmt4d_scalable(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
+  linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
+               outs(%C_in: memref<2x16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @batch_mmt4d_scalable(
+// CHECK-SAME:      %[[A:.*]]: memref<2x16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<2x16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[C16_M:.*]] = arith.constant 16 : index
+// CHECK:           %[[C16_N:.*]] = arith.constant 16 : index
+// CHECK:           %[[C16_K:.*]] = arith.constant 16 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           %[[DIM_N_IN:.*]] = memref.dim %[[B]], %[[C3]] : memref<2x16x16x?x1xf32>
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_1:.*]] = vector.create_mask %[[C2]], %[[C16_N]], %[[C16_K]], %[[DIM_N_IN]], %[[C1]] : vector<2x16x16x[4]x1xi1>
+// CHECK:           %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32> } : vector<2x16x16x[4]x1xi1> -> vector<2x16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_2:.*]] = vector.create_mask %[[C2]], %[[C16_M]], %[[C16_N]], %[[C8]], %[[DIM_N_IN]] : vector<2x16x16x8x[4]xi1>
+// CHECK:           %[[VEC_C:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> } : vector<2x16x16x8x[4]xi1> -> vector<2x16x16x8x[4]xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
+// CHECK:           %[[MASK_3:.*]] = vector.create_mask %[[C2]], %[[C16_M]], %[[C16_N]], %[[C16_K]], %[[C8]], %[[DIM_N_IN]], %[[C1]] : vector<2x16x16x16x8x[4]x1xi1>
+// CHECK:           %[[RED:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32> } : vector<2x16x16x16x8x[4]x1xi1> -> vector<2x16x16x8x[4]xf32>
+// CHECK:           vector.mask %[[MASK_2]] { vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32> } : vector<2x16x16x8x[4]xi1>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: memref<2x16x16x?x1xf32>, %C_in: memref<2x16x16x8x?xf32>) {
+  linalg.batch_mmt4d ins(%A, %B: memref<2x16x16x8x1xf32>, memref<2x16x16x?x1xf32>)
+               outs(%C_in: memref<2x16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL:   func.func @batch_mmt4d_scalable_with_assume(
+// CHECK-SAME:      %[[A:.*]]: memref<2x16x16x8x1xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<2x16x16x?x1xf32>,
+// CHECK-SAME:      %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
+// CHECK-NOT:       mask
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
+// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
+// CHECK:           vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %batch_mmt4d = transform.structured.match ops{["linalg.batch_mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %batch_mmt4d vector_sizes [2, 16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
+    transform.yield
+  }
+}
+
+
 // -----
 
 ///----------------------------------------------------------------------------------------


        


More information about the Mlir-commits mailing list