[Mlir-commits] [mlir] [MLIR][Linalg] Scalable Vectorization of Reduction (PR #97788)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 4 22:25:24 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sve

Author: Zhaoshi Zheng (zhaoshiz)

<details>
<summary>Changes</summary>

Allow scalable vectorization of `linalg.reduce` ops and of `linalg.generic` ops that have a reduction iterator. For now, a reduction dimension can only be vectorized as scalable when it is the trailing dimension.

---

Patch is 22.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97788.diff


6 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+27-1) 
- (modified) mlir/test/Dialect/Linalg/vectorization-scalable.mlir (+80) 
- (modified) mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir (+24) 
- (added) mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/generic_reduce_2d.mlir (+95) 
- (added) mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir (+90) 
- (added) mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir (+91) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3a75d2ac08157..b1aae46237451 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -582,6 +582,12 @@ static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
       llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
 }
 
+static bool isLinalgReduction(LinalgOp &op) {
+  return isa<linalg::ReduceOp>(op) ||
+         (isa<linalg::GenericOp>(op) &&
+          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
+}
+
 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
 /// currently being vectorized. If `dest` has null rank, build an memref.store.
@@ -1773,6 +1779,9 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
   if (isa<ConvolutionOpInterface>(op.getOperation()))
     return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
 
+  if (isLinalgReduction(op))
+    return reductionPreconditions(op);
+
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
   if (!isElementwise(op) &&
@@ -1942,13 +1951,30 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (inputVectorSizes.empty())
     return success();
 
+  auto linalgOp = dyn_cast<LinalgOp>(op);
+  if (linalgOp && isLinalgReduction(linalgOp)) {
+    LDBG("Checking reduce op dims for scalable vectorization\n");
+    auto iteratorTypes = linalgOp.getIteratorTypesArray();
+    assert(iteratorTypes.size() == inputScalableVecDims.size() &&
+           "Number of iterator types and input scalable dims mismatch");
+    // For now, only support scalable vectorization of a reduction on the
+    // trailing dim.
+    for (size_t i = 0; i < inputScalableVecDims.size() - 1; ++i) {
+      if (inputScalableVecDims[i] && isReductionIterator(iteratorTypes[i])) {
+        LDBG("Non-trailing reduction dim requested for scalable "
+             "vectorization\n");
+        return failure();
+      }
+    }
+    return success();
+  }
+
   bool isScalable = inputScalableVecDims.back();
   if (!isScalable)
     return success();
 
   // Only element-wise and 1d depthwise conv ops supported in the presence of
   // scalable dims.
-  auto linalgOp = dyn_cast<LinalgOp>(op);
   return success(linalgOp && (isElementwise(linalgOp) ||
                               isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
 }
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index d6f8d78358370..e0dae167b8625 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -142,3 +142,83 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+func.func @vectorize_dynamic_reduction_1d(%arg0: tensor<?xf32>,
+                                          %arg1: tensor<f32>) -> tensor<f32> {
+
+  %0 = linalg.reduce ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<f32>) dimensions = [0]
+  (%in: f32, %init: f32) {
+    %0 = arith.addf %in, %init : f32
+    linalg.yield %0 : f32
+  }
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL:  func.func @vectorize_dynamic_reduction_1d(
+// CHECK-SAME:     %[[ARG_0:.*]]: tensor<?xf32>, %[[ARG_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK:          %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:          %[[VAL_1:.*]] = tensor.dim %[[ARG_0]], %[[VAL_0]] : tensor<?xf32>
+// CHECK:          %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:          %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:          %[[VAL_4:.*]] = vector.create_mask %[[VAL_1]] : vector<[4]xi1>
+// CHECK:          %[[VAL_5:.*]] = vector.mask %[[VAL_4]] { vector.transfer_read %[[ARG_0]][%[[VAL_2]]], %[[VAL_3]] {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK:          %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:          %[[VAL_7:.*]] = vector.transfer_read %[[ARG_1]][], %[[VAL_6]] : tensor<f32>, vector<f32>
+// CHECK:          %[[VAL_8:.*]] = vector.extractelement %[[VAL_7]][] : vector<f32>
+// CHECK:          %[[VAL_9:.*]] = vector.mask %[[VAL_4]] { vector.multi_reduction <add>, %[[VAL_5]], %[[VAL_8]] [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
+// CHECK:          %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : f32 to vector<f32>
+// CHECK:          %[[VAL_11:.*]] = vector.transfer_write %[[VAL_10]], %[[ARG_1]][] : vector<f32>, tensor<f32>
+// CHECK:          return %[[VAL_11]] : tensor<f32>
+// CHECK:        }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [[4]] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @vectorize_dynamic_reduction_2d(%arg0: tensor<?x?xf32>,
+                                          %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                         affine_map<(d0, d1) -> (d0)>],
+                        iterator_types = ["parallel", "reduction"] }
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%arg1 : tensor<?xf32>) {
+    ^bb(%in: f32, %out: f32) :
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL:  func.func @vectorize_dynamic_reduction_2d(
+// CHECK-SAME:     %[[ARG_0:.*]]: tensor<?x?xf32>, %[[ARG_1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:    %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:    %[[VAL_1:.*]] = tensor.dim %[[ARG_0]], %[[VAL_0]] : tensor<?x?xf32>
+// CHECK:    %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:    %[[VAL_3:.*]] = tensor.dim %[[ARG_0]], %[[VAL_2]] : tensor<?x?xf32>
+// CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:    %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:    %[[VAL_6:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_3]] : vector<1x[4]xi1>
+// CHECK:    %[[VAL_7:.*]] = vector.mask %[[VAL_6]] { vector.transfer_read %[[ARG_0]][%[[VAL_4]], %[[VAL_4]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK:    %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:    %[[VAL_9:.*]] = vector.create_mask %[[VAL_1]] : vector<1xi1>
+// CHECK:    %[[VAL_10:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %[[ARG_1]][%[[VAL_4]]], %[[VAL_8]] {in_bounds = [true]} : tensor<?xf32>, vector<1xf32> } : vector<1xi1> -> vector<1xf32>
+// CHECK:    %[[VAL_11:.*]] = vector.mask %[[VAL_6]] { vector.multi_reduction <add>, %[[VAL_7]], %[[VAL_10]] [1] : vector<1x[4]xf32> to vector<1xf32> } : vector<1x[4]xi1> -> vector<1xf32>
+// CHECK:    %[[VAL_12:.*]] = arith.constant 0 : index
+// CHECK:    %[[VAL_13:.*]] = vector.mask %[[VAL_9]] { vector.transfer_write %[[VAL_11]], %[[ARG_1]][%[[VAL_12]]] {in_bounds = [true]} : vector<1xf32>, tensor<?xf32> } : vector<1xi1> -> tensor<?xf32>
+// CHECK:    return %[[VAL_13]] : tensor<?xf32>
+// CHECK:  }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, [4]] : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index f70d23a193229..03cdd4f1cc2b6 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -298,6 +298,30 @@ func.func @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) ->
 // CHECK:          %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
 // CHECK:          return %[[VAL_4]] : f32
 
+func.func @scalable_dim_2d(%A: vector<2x[4]xf32>, %B: vector<2xf32>, %C: vector<2x[4]xi1>) -> vector<2xf32> {
+    %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [1] : vector<2x[4]xf32> to vector<2xf32> } : vector<2x[4]xi1> -> vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL:  func.func @scalable_dim_2d(
+// CHECK-SAME:                                      %[[ARG_0:.*]]: vector<2x[4]xf32>,
+// CHECK-SAME:                                      %[[ARG_1:.*]]: vector<2xf32>,
+// CHECK-SAME:                                      %[[ARG_2:.*]]: vector<2x[4]xi1>) -> vector<2xf32> {
+// CHECK-DAG:      %[[CON_0:.*]] = arith.constant 1 : index
+// CHECK-DAG:      %[[CON_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:      %[[CON_2:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK:          %[[VAL_0:.*]] = vector.extract %[[ARG_0]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:          %[[VAL_1:.*]] = vector.extract %[[ARG_1]][0] : f32 from vector<2xf32>
+// CHECK:          %[[VAL_2:.*]] = vector.extract %[[ARG_2]][0] : vector<[4]xi1> from vector<2x[4]xi1>
+// CHECK:          %[[VAL_3:.*]] = vector.mask %[[VAL_2]] { vector.reduction <add>, %[[VAL_0]], %[[VAL_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK:          %[[VAL_4:.*]] = vector.insertelement %[[VAL_3]], %[[CON_2]][%[[CON_1]] : index] : vector<2xf32>
+// CHECK:          %[[VAL_5:.*]] = vector.extract %[[ARG_0]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:          %[[VAL_6:.*]] = vector.extract %[[ARG_1]][1] : f32 from vector<2xf32>
+// CHECK:          %[[VAL_7:.*]] = vector.extract %[[ARG_2]][1] : vector<[4]xi1> from vector<2x[4]xi1>
+// CHECK:          %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.reduction <add>, %[[VAL_5]], %[[VAL_6]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK:          %[[VAL_9:.*]] = vector.insertelement %[[VAL_8]], %[[VAL_4]][%[[CON_0]] : index] : vector<2xf32>
+// CHECK:          return %[[VAL_9]] : vector<2xf32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/generic_reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/generic_reduce_2d.mlir
new file mode 100644
index 0000000000000..42a6f55e56a6f
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/generic_reduce_2d.mlir
@@ -0,0 +1,95 @@
+// DEFINE: %{compile} =  mlir-opt %s \
+// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE:    -one-shot-bufferize="bufferize-function-boundaries" -buffer-deallocation-pipeline -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE:    -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = generic_reduce_2d_f32
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE:    -shared-libs=%mlir_native_utils_lib_dir/libmlir_runner_utils%shlibext,%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s --check-prefix=F32
+
+func.func @generic_reduce_2d_f32() {
+  // 2-D Tensor
+  %M = arith.constant 16 : index
+  %N = arith.constant 1000 : index
+  %c0_f32 = arith.constant 0.0 : f32
+
+  // Allocate the input and output tensors
+  %A_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xf32>
+  %C_alloc = bufferization.alloc_tensor(%M) : tensor<?xf32>
+
+  // Initialise the tensors
+  %pi = arith.constant  3.1416 : f32
+  %A_in = linalg.fill ins(%pi : f32) outs(%A_alloc : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %C_in = linalg.fill ins(%c0_f32 : f32) outs(%C_alloc : tensor<?xf32>) -> tensor<?xf32>
+
+  // Reduce
+  %C_out = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                         affine_map<(d0, d1) -> (d0)>],
+                        iterator_types = ["parallel", "reduction"] }
+    ins(%A_in : tensor<?x?xf32>)
+    outs(%C_in : tensor<?xf32>) {
+    ^bb(%in: f32, %out: f32) :
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+
+  // Print and verify the output
+  // F32-LABEL: SVE: START OF TEST OUTPUT
+  vector.print str "SVE: START OF TEST OUTPUT\n"
+
+  // F32-NEXT: Unranked Memref {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data =
+  // F32-NEXT: [3141.6,  3141.6,  3141.6,  3141.6,  3141.6,  3141.6,  3141.6,  3141.6,  3141.6,  3141.6,  3141.6,  3141.6,  3141.6,  3141.6,  3141.6,  3141.6]
+
+  %xf = tensor.cast %C_out : tensor<?xf32> to tensor<*xf32>
+  call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
+
+  // F32-NEXT: SVE: END OF TEST OUTPUT
+  vector.print str "SVE: END OF TEST OUTPUT\n"
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  // A sequence that will tile and vectorise a Reduce Op
+  transform.named_sequence @tile_and_vectorize_reduce(%func
+    : !transform.op<"func.func"> {transform.readonly}) {
+
+    // Step 0: Get a handle to the reduce Op
+    %reduce = transform.structured.match ops{["linalg.generic"]} in %func
+      : (!transform.op<"func.func">) -> !transform.any_op
+
+    // Step 1: Tile
+    %tiled_reduce, %loops:2 = transform.structured.tile_using_for %reduce tile_sizes [1, [4]]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize
+    transform.structured.vectorize %tiled_reduce vector_sizes [1, [4]] : !transform.any_op
+
+    // Step 3: Lower vector.multi_reduction
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_masked_transfers
+      transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
+
+  // A sequence that goes over all functions in this module and applies
+  // "tile_and_vectorize_reduce"
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %module
+        : (!transform.any_op) -> !transform.op<"func.func">
+
+    transform.foreach %funcs : !transform.op<"func.func"> {
+      ^bb2(%func : !transform.op<"func.func">):
+        transform.include @tile_and_vectorize_reduce failures(propagate)
+        (%func) : (!transform.op<"func.func">) -> ()
+    }
+    transform.yield
+  }
+}
+
+func.func private @printMemrefF32(%ptr : tensor<*xf32>)
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
new file mode 100644
index 0000000000000..e9f7154b10d42
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
@@ -0,0 +1,90 @@
+// DEFINE: %{compile} =  mlir-opt %s \
+// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE:    -one-shot-bufferize="bufferize-function-boundaries" -buffer-deallocation-pipeline -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE:    -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = reduce_1d_f32
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE:    -shared-libs=%mlir_native_utils_lib_dir/libmlir_runner_utils%shlibext,%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s --check-prefix=F32
+
+func.func @reduce_1d_f32() {
+  // 1-D Tensor
+  %N = arith.constant 1000 : index
+  %c0_f32 = arith.constant 0.0 : f32
+
+  // Allocate the input and output tensors
+  %A_alloc = bufferization.alloc_tensor(%N) : tensor<?xf32>
+  %C_alloc = bufferization.alloc_tensor() : tensor<f32>
+
+  // Initialise the tensors
+  %pi = arith.constant  3.1416 : f32
+  %A_in = linalg.fill ins(%pi : f32) outs(%A_alloc : tensor<?xf32>) -> tensor<?xf32>
+  %C_in = tensor.insert %c0_f32 into %C_alloc[] : tensor<f32>
+
+  // Reduce
+  %C_out = linalg.reduce ins(%A_in : tensor<?xf32>) outs(%C_in: tensor<f32>) dimensions = [0]
+    (%in: f32, %init: f32) {
+      %0 = arith.addf %in, %init : f32
+      linalg.yield %0 : f32
+    }
+
+  // Print and verify the output
+  // F32-LABEL: SVE: START OF TEST OUTPUT
+  vector.print str "SVE: START OF TEST OUTPUT\n"
+
+  // F32-NEXT: Unranked Memref {{.*}} rank = 0 offset = 0 sizes = [] strides = [] data =
+  // F32-NEXT: [3141.6]
+
+  %xf = tensor.cast %C_out : tensor<f32> to tensor<*xf32>
+  call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
+
+  // F32-NEXT: SVE: END OF TEST OUTPUT
+  vector.print str "SVE: END OF TEST OUTPUT\n"
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  // A sequence that will tile and vectorise a Reduce Op
+  transform.named_sequence @tile_and_vectorize_reduce(%func
+    : !transform.op<"func.func"> {transform.readonly}) {
+
+    // Step 0: Get a handle to the reduce Op
+    %reduce = transform.structured.match ops{["linalg.reduce"]} in %func
+      : (!transform.op<"func.func">) -> !transform.any_op
+
+    // Step 1: Tile
+    %tiled_reduce, %loops:1 = transform.structured.tile_using_for %reduce tile_sizes [[4]]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize
+    transform.structured.vectorize %tiled_reduce vector_sizes [[4]] : !transform.any_op
+
+    // Step 3: Lower vector.multi_reduction
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_masked_transfers
+      transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
+
+  // A sequence that goes over all functions in this module and applies
+  // "tile_and_vectorize_reduce"
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %module
+        : (!transform.any_op) -> !transform.op<"func.func">
+
+    transform.foreach %funcs : !transform.op<"func.func"> {
+      ^bb2(%func : !transform.op<"func.func">):
+        transform.include @tile_and_vectorize_reduce failures(propagate)
+        (%func) : (!transform.op<"func.func">) -> ()
+    }
+    transform.yield
+  }
+}
+
+func.func private @printMemrefF32(%ptr : tensor<*xf32>)
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
new file mode 100644
index 0000000000000..349966d7c85d5
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
@@ -0,0 +1,91 @@
+// DEFINE: %{compile} =  mlir-opt %s \
+// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE:    -one-shot-bufferize="bufferize-function-boundaries" -buffer-deallocation-pipeline -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE:    -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = reduce_2d_f32
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE:    -shared-libs=%mlir_native_utils_lib_dir/libmlir_runner_utils%shlibext,%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s --check-prefix=F32
+
+func.func @reduce_2d_f32() {
+  // 2-D Tensor
+  %M = arith.constant 16 : index
+  %N = arith.constant 1000 : index
+  %c0_f32 = arith.constant 0.0 : f32
+
+  // Allocate the input and output tensors
+  %A_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xf32>
+  %C_alloc = bufferization.alloc_tenso...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/97788


More information about the Mlir-commits mailing list