[Mlir-commits] [mlir] 6942f1d - [MLIR][Linalg] Scalable Vectorization of Reduction on the Trailing Dimension (#97788)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 23 21:52:26 PDT 2024
Author: Zhaoshi Zheng
Date: 2024-07-23T21:52:22-07:00
New Revision: 6942f1d5aa232face8269ce78c4de7d45571a8e9
URL: https://github.com/llvm/llvm-project/commit/6942f1d5aa232face8269ce78c4de7d45571a8e9
DIFF: https://github.com/llvm/llvm-project/commit/6942f1d5aa232face8269ce78c4de7d45571a8e9.diff
LOG: [MLIR][Linalg] Scalable Vectorization of Reduction on the Trailing Dimension (#97788)
Allow scalable vectorization of linalg::reduce and of linalg::generic ops that
have reduction iterator(s), subject to two restrictions:
1. The reduction dim is the last (innermost) dim of the op; and
2. Only the reduction dim is requested for scalable vectorization.
One exception: scalable vectorization of the reduction dim in Matmul-like ops
is not supported, even if the above restrictions are met.
Allowed combinations of scalable flags and iterator types:
Matmul:
Iterators: ["parallel", "parallel", "reduction"]
Scalable Flags: ["true", "true", "false"]
["false", "true", "false"]
Matvec:
Iterators: ["parallel", "reduction"]
Scalable Flags: ["false", "true"]
["true", "false"]
Added:
mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization-scalable.mlir
mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 7f7168eb86832..c4dab7d061b4b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -586,6 +586,14 @@ static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}
+/// Check if `op` is a linalg.reduce, or a linalg.generic with at least one
+/// reduction iterator. Other ops (e.g. linalg.matvec) always yield false.
+static bool hasReductionIterator(LinalgOp &op) {
+ return isa<linalg::ReduceOp>(op) ||
+ (isa<linalg::GenericOp>(op) &&
+ llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
+}
+
/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If `dest` has null rank, build an memref.store.
@@ -1787,6 +1795,9 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
if (isa<ConvolutionOpInterface>(op.getOperation()))
return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
+ if (hasReductionIterator(op))
+ return reductionPreconditions(op);
+
// TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
// linalg.copy ops and ops that implement ContractionOpInterface for now.
if (!isElementwise(op) &&
@@ -1976,6 +1987,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
// 1. exactly 1 dim is scalable and that's the _last_ parallel dim
// 2. exactly 2 dims are scalable and those are the _last two adjacent_
// parallel dims
+ // 3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
// The 2nd restriction above means that only Matmul-like Ops are supported
// when 2 dims are scalable, e.g. :
// * iterators = [parallel, parallel, reduction]
@@ -1992,19 +2004,45 @@ vectorizeScalableVectorPrecondition(Operation *op,
scalableFlags.pop_back();
}
- // TODO: Support scalable vectorisation for reduction dims
- if (iterators.back() == utils::IteratorType::reduction)
- return failure();
-
- // If this is not the _last_ parallel dim, 1. above is not met
- if (seenParalell)
- return failure();
+ switch (iterators.back()) {
+ case utils::IteratorType::reduction: {
+ // Check 3. above is met.
+ if (iterators.size() != inputVectorSizes.size()) {
+ LDBG("Non-trailing reduction dim requested for scalable "
+ "vectorization\n");
+ return failure();
+ }
+ if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+ LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
+ "is not supported\n");
+ return failure();
+ }
+ break;
+ }
+ case utils::IteratorType::parallel: {
+ // Check 1. and 2. above are met.
+ if (seenParalell) {
+ LDBG("Inner parallel dim not requested for scalable "
+ "vectorization\n");
+ return failure();
+ }
+ break;
+ }
+ }
// If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
// supported for which expect the folowing config:
// * iterators = [parallel, parallel, reduction]
// * scalable flags = [true, true, false]
if (numOfScalableDims == 2) {
+ // Disallow below case which breaks 3. above:
+ // * iterators = [..., parallel, reduction]
+ // * scalable flags = [..., true, true]
+ if (iterators.back() == utils::IteratorType::reduction) {
+ LDBG("Higher dim than the trailing reduction dim requested for scalable "
+ "vectorization\n");
+ return failure();
+ }
scalableFlags.pop_back();
iterators.pop_back();
@@ -2017,7 +2055,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
isa<linalg::MatmulTransposeAOp>(op) ||
- isa<linalg::DepthwiseConv1DNwcWcOp>(op));
+ isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
+ isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index 4423ee6ea6a51..4ee3088cc3778 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -189,3 +189,168 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @vectorize_dynamic_reduction_scalable_1d(%arg0: tensor<?xf32>,
+ %arg1: tensor<f32>) -> tensor<f32> {
+
+ %0 = linalg.reduce ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<f32>) dimensions = [0]
+ (%in: f32, %init: f32) {
+ %0 = arith.addf %in, %init : f32
+ linalg.yield %0 : f32
+ }
+ return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_reduction_scalable_1d(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?xf32>, %[[ARG_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK: %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_A0_0:.*]] = tensor.dim %[[ARG_0]], %[[C0_idx]] : tensor<?xf32>
+// CHECK: %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK: %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_A0_0]] : vector<[4]xi1>
+// CHECK: %[[VEC_RD_0:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VEC_RD_1:.*]] = vector.transfer_read %[[ARG_1]][], %[[C0_F32]] : tensor<f32>, vector<f32>
+// CHECK: %[[ACC_f32:.*]] = vector.extractelement %[[VEC_RD_1]][] : vector<f32>
+// CHECK: %[[REDUCE:.*]] = vector.mask %[[MASK]] { vector.multi_reduction <add>, %[[VEC_RD_0]], %[[ACC_f32]] [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
+// CHECK: %[[VEC_f32:.*]] = vector.broadcast %[[REDUCE]] : f32 to vector<f32>
+// CHECK: %{{.*}} = vector.transfer_write %[[VEC_f32]], %[[ARG_1]][] : vector<f32>, tensor<f32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [[4]] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Note: scalable version of `vectorize_dynamic_reduction` in test/Dialect/Linalg/vectorization.mlir.
+func.func @vectorize_dynamic_reduction_scalable_2d(%arg0: tensor<?x?xf32>,
+ %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"] }
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%arg1 : tensor<?xf32>) {
+ ^bb(%in: f32, %out: f32) :
+ %0 = arith.addf %in, %out : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_reduction_scalable_2d(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>, %[[ARG_1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_A0_0:.*]] = tensor.dim %[[ARG_0]], %[[C0_idx]] : tensor<?x?xf32>
+// CHECK: %[[C1_idx:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_A0_1:.*]] = tensor.dim %[[ARG_0]], %[[C1_idx]] : tensor<?x?xf32>
+// CHECK: %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK: %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[MASK_2d:.*]] = vector.create_mask %[[DIM_A0_0]], %[[DIM_A0_1]] : vector<4x[8]xi1>
+// CHECK: %[[VEC_RD_0:.*]] = vector.mask %[[MASK_2d]] { vector.transfer_read %[[ARG_0]][%[[C0_idx]], %[[C0_idx]]], %[[C0_f32]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x[8]xf32> } : vector<4x[8]xi1> -> vector<4x[8]xf32>
+// CHECK: %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[MASK_1d:.*]] = vector.create_mask %[[DIM_A0_0]] : vector<4xi1>
+// CHECK: %[[VEC_RD_1:.*]] = vector.mask %[[MASK_1d]] { vector.transfer_read %[[ARG_1]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[REDUCE:.*]] = vector.mask %[[MASK_2d]] { vector.multi_reduction <add>, %[[VEC_RD_0]], %[[VEC_RD_1]] [1] : vector<4x[8]xf32> to vector<4xf32> } : vector<4x[8]xi1> -> vector<4xf32>
+// CHECK: %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK: %{{.*}} = vector.mask %[[MASK_1d]] { vector.transfer_write %[[REDUCE]], %[[ARG_1]][%[[C0_idx]]] {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, [8]] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @vectorize_dynamic_matvec_trailing_reduction_dim(%arg0: tensor<?x?xf32>,
+ %arg1: tensor<?xf32>,
+ %arg2: tensor<?xf32>) {
+ linalg.matvec ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+ return
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_matvec_trailing_reduction_dim(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>, %[[ARG_1:.*]]: tensor<?xf32>, %[[ARG_2:.*]]: tensor<?xf32>) {
+// CHECK: %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_A0_0:.*]] = tensor.dim %[[ARG_0]], %[[C0_idx]] : tensor<?x?xf32>
+// CHECK: %[[C1_idx:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_A0_1:.*]] = tensor.dim %[[ARG_0]], %[[C1_idx]] : tensor<?x?xf32>
+// CHECK: %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK: %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[MASK_2d:.*]] = vector.create_mask %[[DIM_A0_0]], %[[DIM_A0_1]] : vector<4x[4]xi1>
+// CHECK: %[[VEC_RD_0:.*]] = vector.mask %[[MASK_2d]] { vector.transfer_read %[[ARG_0]][%[[C0_idx]], %[[C0_idx]]], %[[C0_f32]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x[4]xf32> } : vector<4x[4]xi1> -> vector<4x[4]xf32>
+// CHECK: %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[MASK_d1:.*]] = vector.create_mask %[[DIM_A0_1]] : vector<[4]xi1>
+// CHECK: %[[VEC_RD_1:.*]] = vector.mask %[[MASK_d1]] { vector.transfer_read %[[ARG_1]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true, true], permutation_map = #map} : tensor<?xf32>, vector<4x[4]xf32> } : vector<[4]xi1> -> vector<4x[4]xf32>
+// CHECK: %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[MASK_d2:.*]] = vector.create_mask %[[DIM_A0_0]] : vector<4xi1>
+// CHECK: %[[VEC_RD_2:.*]] = vector.mask %[[MASK_d2]] { vector.transfer_read %[[ARG_2]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_RD_0:.*]], %[[VEC_RD_1:.*]] : vector<4x[4]xf32>
+// CHECK: %[[REDUCE:.*]] = vector.mask %[[MASK_2d]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_RD_2]] [1] : vector<4x[4]xf32> to vector<4xf32> } : vector<4x[4]xi1> -> vector<4xf32>
+// CHECK: %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK: %{{.*}} = vector.mask %[[MASK_d2]] { vector.transfer_write %[[REDUCE]], %[[ARG_2]][%[[C0_idx]]] {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, [4]] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @vectorize_dynamic_generic_matvec_leading_parallel_dim(%arg0: tensor<?x?xf32>,
+ %arg1: tensor<?xf32>,
+ %arg2: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"] }
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?xf32>) {
+ ^bb(%mat: f32, %vec: f32, %res: f32) :
+ %0 = arith.mulf %mat, %vec : f32
+ %1 = arith.addf %res, %0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_generic_matvec_leading_parallel_dim(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>, %[[ARG_1:.*]]: tensor<?xf32>, %[[ARG_2:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_A0_0:.*]] = tensor.dim %[[ARG_0]], %[[C0_idx]] : tensor<?x?xf32>
+// CHECK: %[[C1_idx:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_A0_1:.*]] = tensor.dim %[[ARG_0]], %[[C1_idx]] : tensor<?x?xf32>
+// CHECK: %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK: %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[MASK_2d:.*]] = vector.create_mask %[[DIM_A0_0]], %[[DIM_A0_1]] : vector<[4]x4xi1>
+// CHECK: %[[VEC_RD_0:.*]] = vector.mask %[[MASK_2d]] { vector.transfer_read %[[ARG_0]][%[[C0_idx]], %[[C0_idx]]], %[[C0_f32]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[4]x4xf32> } : vector<[4]x4xi1> -> vector<[4]x4xf32>
+// CHECK: %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[MASK_d1:.*]] = vector.create_mask %[[DIM_A0_1]] : vector<4xi1>
+// CHECK: %[[VEC_RD_1:.*]] = vector.mask %[[MASK_d1]] { vector.transfer_read %[[ARG_1]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true, true], permutation_map = #map} : tensor<?xf32>, vector<[4]x4xf32> } : vector<4xi1> -> vector<[4]x4xf32>
+// CHECK: %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[MASK_d2:.*]] = vector.create_mask %[[DIM_A0_0]] : vector<[4]xi1>
+// CHECK: %[[VEC_RD_2:.*]] = vector.mask %[[MASK_d2]] { vector.transfer_read %[[ARG_2]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_RD_0:.*]], %[[VEC_RD_1:.*]] : vector<[4]x4xf32>
+// CHECK: %[[REDUCE:.*]] = vector.mask %[[MASK_2d]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_RD_2]] [1] : vector<[4]x4xf32> to vector<[4]xf32> } : vector<[4]x4xi1> -> vector<[4]xf32>
+// CHECK: %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK: %{{.*}} = vector.mask %[[MASK_d2]] { vector.transfer_write %[[REDUCE]], %[[ARG_2]][%[[C0_idx]]] {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } : vector<[4]xi1> -> tensor<?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [[4], 4] : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index c7ec39b0dbfb3..e9f8e08ca0c6b 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -129,35 +129,35 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @linalg_reduce_scalable(%input: tensor<?xf32>,
- %acc: tensor<f32>) -> tensor<f32> {
+func.func @linalg_reduce_scalable_leading_dim(%input: tensor<?x?xf32>,
+ %acc: tensor<?xf32>) -> tensor<?xf32> {
// expected-error @+1 {{Attempted to vectorize, but failed}}
- %0 = linalg.reduce ins(%input : tensor<?xf32>) outs(%acc : tensor<f32>) dimensions = [0]
+ %0 = linalg.reduce ins(%input : tensor<?x?xf32>) outs(%acc : tensor<?xf32>) dimensions = [0]
(%in: f32, %init: f32) {
%0 = arith.addf %in, %init : f32
linalg.yield %0 : f32
}
- return %0 : tensor<f32>
+ return %0 : tensor<?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [[4]] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [[4], 1] : !transform.any_op
transform.yield
}
}
// -----
-func.func @linalg_generic_scalable_reduction_dim(%input: tensor<?x?xf32>,
- %acc: tensor<?xf32>) -> tensor<?xf32> {
+func.func @linalg_generic_reduction_scalable_leading_dim(%input: tensor<?x?xf32>,
+ %acc: tensor<?xf32>) -> tensor<?xf32> {
// expected-error @+1 {{Attempted to vectorize, but failed}}
%0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0)>],
- iterator_types = ["parallel", "reduction"] }
+ affine_map<(d0, d1) -> (d1)>],
+ iterator_types = ["reduction", "parallel"] }
ins(%input : tensor<?x?xf32>)
outs(%acc : tensor<?xf32>) {
^bb(%in: f32, %out: f32) :
@@ -170,7 +170,24 @@ func.func @linalg_generic_scalable_reduction_dim(%input: tensor<?x?xf32>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [1, [4]] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [[4], 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @linalg_matvec_scalable_two_dims(%A: memref<?x?xf32>, %B: memref<?xf32>, %C: memref<?xf32>) {
+ // expected-error @+1 {{Attempted to vectorize, but failed}}
+ linalg.matvec ins(%A, %B: memref<?x?xf32>, memref<?xf32>)
+ outs(%C: memref<?xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matvec"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %matmul vector_sizes [[4], [4]] : !transform.any_op
transform.yield
}
}
@@ -180,7 +197,7 @@ module attributes {transform.with_named_sequence} {
func.func @linalg_matmul_scalable_leading_parallel_dim(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
// expected-error @+1 {{Attempted to vectorize, but failed}}
linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
- outs(%C: memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
return
}
@@ -191,3 +208,48 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @linalg_matmul_scalable_trailing_reduction_dim(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+ // expected-error @+1 {{Attempted to vectorize, but failed}}
+ linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %matmul vector_sizes [8, 16, [4]] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @linalg_generic_matmul_scalable_two_trailing_dims(%A: tensor<?x64xf32>, %B: tensor<64x?xf32>,
+ %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+ // expected-error @+1 {{Attempted to vectorize, but failed}}
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"] }
+ ins(%A, %B : tensor<?x64xf32>, tensor<64x?xf32>)
+ outs(%C: tensor<?x?xf32>) {
+ ^bb(%in1: f32, %in2: f32, %out: f32) :
+ %0 = arith.mulf %in1, %in2 : f32
+ %1 = arith.addf %0, %out : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [2, [4], [4]] : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
new file mode 100644
index 0000000000000..7cdb35918c4c0
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
@@ -0,0 +1,175 @@
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE: -one-shot-bufferize="bufferize-function-boundaries" -buffer-deallocation-pipeline -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = reduce_1d_f32
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE: -shared-libs=%mlir_native_utils_lib_dir/libmlir_runner_utils%shlibext,%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s --check-prefix=REDUCE-F32
+
+// REDEFINE: %{entry_point} = reduce_1d_i32
+// RUN: %{run} | FileCheck %s --check-prefix=REDUCE-I32
+
+// REDEFINE: %{entry_point} = generic_reduce_1d_f32
+// RUN: %{run} | FileCheck %s --check-prefix=GENERIC-F32
+
+func.func @reduce_1d_f32() {
+ // 1-D Tensor
+ %N = arith.constant 1000 : index
+ %c0_f32 = arith.constant 0.0 : f32
+
+ // Allocate the input and output tensors
+ %A_alloc = bufferization.alloc_tensor(%N) : tensor<?xf32>
+ %C_alloc = bufferization.alloc_tensor() : tensor<f32>
+
+ // Initialise the tensors
+ %pi = arith.constant 3.1416 : f32
+ %A_in = linalg.fill ins(%pi : f32) outs(%A_alloc : tensor<?xf32>) -> tensor<?xf32>
+ %C_in = tensor.insert %c0_f32 into %C_alloc[] : tensor<f32>
+
+ // Reduce
+ %C_out = linalg.reduce ins(%A_in : tensor<?xf32>) outs(%C_in: tensor<f32>) dimensions = [0]
+ (%in: f32, %init: f32) {
+ %0 = arith.addf %in, %init : f32
+ linalg.yield %0 : f32
+ }
+
+ // Print and verify the output
+ // REDUCE-F32-LABEL: SVE: START OF TEST OUTPUT
+ vector.print str "SVE: START OF TEST OUTPUT\n"
+
+ // REDUCE-F32-NEXT: Unranked Memref {{.*}} rank = 0 offset = 0 sizes = [] strides = [] data =
+ // REDUCE-F32-NEXT: [3141.6]
+
+ %xf = tensor.cast %C_out : tensor<f32> to tensor<*xf32>
+ call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
+
+ // REDUCE-F32-NEXT: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT\n"
+
+ return
+}
+
+func.func @reduce_1d_i32() {
+ // 1-D Tensor
+ %N = arith.constant 1000 : index
+ %c0_i32 = arith.constant 0 : i32
+
+ // Allocate the input and output tensors
+ %A_alloc = bufferization.alloc_tensor(%N) : tensor<?xi32>
+ %C_alloc = bufferization.alloc_tensor() : tensor<i32>
+
+ // Initialise the tensors
+ %pi = arith.constant 3 : i32
+ %A_in = linalg.fill ins(%pi : i32) outs(%A_alloc : tensor<?xi32>) -> tensor<?xi32>
+ %C_in = tensor.insert %c0_i32 into %C_alloc[] : tensor<i32>
+
+ // Reduce
+ %C_out = linalg.reduce ins(%A_in : tensor<?xi32>) outs(%C_in: tensor<i32>) dimensions = [0]
+ (%in: i32, %init: i32) {
+ %0 = arith.addi %in, %init : i32
+ linalg.yield %0 : i32
+ }
+
+ // Print and verify the output
+ // REDUCE-I32-LABEL: SVE: START OF TEST OUTPUT
+ vector.print str "SVE: START OF TEST OUTPUT\n"
+
+ // REDUCE-I32-NEXT: Unranked Memref {{.*}} rank = 0 offset = 0 sizes = [] strides = [] data =
+ // REDUCE-I32-NEXT: [3000]
+
+ %xf = tensor.cast %C_out : tensor<i32> to tensor<*xi32>
+ call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()
+
+ // REDUCE-I32-NEXT: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT\n"
+
+ return
+}
+
+func.func @generic_reduce_1d_f32() {
+ // 1-D Tensor
+ %N = arith.constant 1000 : index
+ %c0_f32 = arith.constant 0.0 : f32
+
+ // Allocate the input and output tensors
+ %A_alloc = bufferization.alloc_tensor(%N) : tensor<?xf32>
+ %C_alloc = bufferization.alloc_tensor() : tensor<f32>
+
+ // Initialise the tensors
+ %pi = arith.constant 3.1416 : f32
+ %A_in = linalg.fill ins(%pi : f32) outs(%A_alloc : tensor<?xf32>) -> tensor<?xf32>
+ %C_in = tensor.insert %c0_f32 into %C_alloc[] : tensor<f32>
+
+ // Reduce
+ %C_out = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> ()>],
+ iterator_types = ["reduction"] }
+ ins(%A_in : tensor<?xf32>)
+ outs(%C_in : tensor<f32>) {
+ ^bb(%in: f32, %out: f32) :
+ %0 = arith.addf %in, %out : f32
+ linalg.yield %0 : f32
+ } -> tensor<f32>
+
+ // Print and verify the output
+ // GENERIC-F32-LABEL: SVE: START OF TEST OUTPUT
+ vector.print str "SVE: START OF TEST OUTPUT\n"
+
+ // GENERIC-F32-NEXT: Unranked Memref {{.*}} rank = 0 offset = 0 sizes = [] strides = [] data =
+ // GENERIC-F32-NEXT: [3141.6]
+
+ %xf = tensor.cast %C_out : tensor<f32> to tensor<*xf32>
+ call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
+
+ // GENERIC-F32-NEXT: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT\n"
+
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ // A sequence that will tile and vectorise a Reduce Op
+ transform.named_sequence @tile_and_vectorize_reduce(%func
+ : !transform.op<"func.func"> {transform.readonly}) {
+
+ // Step 0: Get a handle to the reduce Op
+ %reduce = transform.structured.match ops{["linalg.reduce", "linalg.generic"]} in %func
+ : (!transform.op<"func.func">) -> !transform.any_op
+
+ // Step 1: Tile
+ %tiled_reduce, %loops:1 = transform.structured.tile_using_for %reduce tile_sizes [[4]]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ // Step 2: Vectorize
+ transform.structured.vectorize %tiled_reduce vector_sizes [[4]] : !transform.any_op
+
+ // Step 3: Lower vector.multi_reduction
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.lower_masked_transfers
+ transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+
+ // A sequence that goes over all functions in this module and applies
+ // "tile_and_vectorize_reduce"
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %funcs = transform.structured.match ops{["func.func"]} in %module
+ : (!transform.any_op) -> !transform.op<"func.func">
+
+ transform.foreach %funcs : !transform.op<"func.func"> {
+ ^bb2(%func : !transform.op<"func.func">):
+ transform.include @tile_and_vectorize_reduce failures(propagate)
+ (%func) : (!transform.op<"func.func">) -> ()
+ }
+ transform.yield
+ }
+}
+
+func.func private @printMemrefF32(%ptr : tensor<*xf32>)
+func.func private @printMemrefI32(%ptr : tensor<*xi32>)
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
new file mode 100644
index 0000000000000..bcfe12e374b4e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
@@ -0,0 +1,183 @@
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE: -one-shot-bufferize="bufferize-function-boundaries" -buffer-deallocation-pipeline -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = reduce_2d_f32
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE: -shared-libs=%mlir_native_utils_lib_dir/libmlir_runner_utils%shlibext,%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s --check-prefix=REDUCE
+
+// REDEFINE: %{entry_point} = generic_reduce_2d_f32
+// RUN: %{run} | FileCheck %s --check-prefix=GENERIC
+
+// REDEFINE: %{entry_point} = generic_reduce_2d_i32
+// RUN: %{run} | FileCheck %s --check-prefix=GENERIC-I32
+
+func.func @reduce_2d_f32() {
+ // 2-D Tensor
+ %M = arith.constant 16 : index
+ %N = arith.constant 1000 : index
+ %c0_f32 = arith.constant 0.0 : f32
+
+ // Allocate the input and output tensors
+ %A_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xf32>
+ %C_alloc = bufferization.alloc_tensor(%M) : tensor<?xf32>
+
+ // Initialise the tensors
+ %pi = arith.constant 3.1416 : f32
+ %A_in = linalg.fill ins(%pi : f32) outs(%A_alloc : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %C_in = linalg.fill ins(%c0_f32 : f32) outs(%C_alloc : tensor<?xf32>) -> tensor<?xf32>
+
+ // Reduce
+ %C_out = linalg.reduce ins(%A_in : tensor<?x?xf32>) outs(%C_in: tensor<?xf32>) dimensions = [1]
+ (%in: f32, %init: f32) {
+ %0 = arith.addf %in, %init : f32
+ linalg.yield %0 : f32
+ }
+
+ // Print and verify the output
+ // REDUCE-LABEL: SVE: START OF TEST OUTPUT
+ vector.print str "SVE: START OF TEST OUTPUT\n"
+
+ // REDUCE-NEXT: Unranked Memref {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data =
+ // REDUCE-NEXT: [3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6]
+
+ %xf = tensor.cast %C_out : tensor<?xf32> to tensor<*xf32>
+ call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
+
+ // REDUCE-NEXT: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT\n"
+
+ return
+}
+
+func.func @generic_reduce_2d_f32() {
+ // 2-D Tensor
+ %M = arith.constant 16 : index
+ %N = arith.constant 1000 : index
+ %c0_f32 = arith.constant 0.0 : f32
+
+ // Allocate the input and output tensors
+ %A_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xf32>
+ %C_alloc = bufferization.alloc_tensor(%M) : tensor<?xf32>
+
+ // Initialise the tensors
+ %pi = arith.constant 3.1416 : f32
+ %A_in = linalg.fill ins(%pi : f32) outs(%A_alloc : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %C_in = linalg.fill ins(%c0_f32 : f32) outs(%C_alloc : tensor<?xf32>) -> tensor<?xf32>
+
+ // Reduce
+ %C_out = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"] }
+ ins(%A_in : tensor<?x?xf32>)
+ outs(%C_in : tensor<?xf32>) {
+ ^bb(%in: f32, %out: f32) :
+ %0 = arith.addf %in, %out : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+
+ // Print and verify the output
+ // GENERIC-LABEL: SVE: START OF TEST OUTPUT
+ vector.print str "SVE: START OF TEST OUTPUT\n"
+
+ // GENERIC-NEXT: Unranked Memref {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data =
+ // GENERIC-NEXT: [3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6, 3141.6]
+
+ %xf = tensor.cast %C_out : tensor<?xf32> to tensor<*xf32>
+ call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
+
+ // GENERIC-NEXT: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT\n"
+
+ return
+}
+
+func.func @generic_reduce_2d_i32() {
+ // 2-D Tensor
+ %M = arith.constant 16 : index
+ %N = arith.constant 1000 : index
+ %c0_i32 = arith.constant 0 : i32
+
+ // Allocate the input and output tensors
+ %A_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xi32>
+ %C_alloc = bufferization.alloc_tensor(%M) : tensor<?xi32>
+
+ // Initialise the tensors
+ %pi = arith.constant 3 : i32
+ %A_in = linalg.fill ins(%pi : i32) outs(%A_alloc : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %C_in = linalg.fill ins(%c0_i32 : i32) outs(%C_alloc : tensor<?xi32>) -> tensor<?xi32>
+
+ // Reduce
+ %C_out = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"] }
+ ins(%A_in : tensor<?x?xi32>)
+ outs(%C_in : tensor<?xi32>) {
+ ^bb(%in: i32, %out: i32) :
+ %0 = arith.addi %in, %out : i32
+ linalg.yield %0 : i32
+ } -> tensor<?xi32>
+
+ // Print and verify the output
+ // GENERIC-I32-LABEL: SVE: START OF TEST OUTPUT
+ vector.print str "SVE: START OF TEST OUTPUT\n"
+
+ // GENERIC-I32-NEXT: Unranked Memref {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data =
+ // GENERIC-I32-NEXT: [3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000, 3000]
+
+ %xf = tensor.cast %C_out : tensor<?xi32> to tensor<*xi32>
+ call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()
+
+ // GENERIC-I32-NEXT: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT\n"
+
+ return
+}
+
+
+module attributes {transform.with_named_sequence} {
+ // A sequence that will tile and vectorise a Reduce Op
+ transform.named_sequence @tile_and_vectorize_reduce(%func
+ : !transform.op<"func.func"> {transform.readonly}) {
+
+ // Step 0: Get a handle to the reduce Op
+ %reduce = transform.structured.match ops{["linalg.reduce", "linalg.generic"]} in %func
+ : (!transform.op<"func.func">) -> !transform.any_op
+
+ // Step 1: Tile
+ %tiled_reduce, %loops:2 = transform.structured.tile_using_for %reduce tile_sizes [1, [4]]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+ // Step 2: Vectorize
+ transform.structured.vectorize %tiled_reduce vector_sizes [1, [4]] : !transform.any_op
+
+ // Step 3: Lower vector.multi_reduction
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.lower_masked_transfers
+ transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+
+ // A sequence that goes over all functions in this module and applies
+ // "tile_and_vectorize_reduce"
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %funcs = transform.structured.match ops{["func.func"]} in %module
+ : (!transform.any_op) -> !transform.op<"func.func">
+
+ transform.foreach %funcs : !transform.op<"func.func"> {
+ ^bb2(%func : !transform.op<"func.func">):
+ transform.include @tile_and_vectorize_reduce failures(propagate)
+ (%func) : (!transform.op<"func.func">) -> ()
+ }
+ transform.yield
+ }
+}
+
+func.func private @printMemrefF32(%ptr : tensor<*xf32>)
+func.func private @printMemrefI32(%ptr : tensor<*xi32>)
More information about the Mlir-commits
mailing list