[Mlir-commits] [mlir] c7a3346 - [mlir][linalg] Fix scalable vectorisation of tensor.extract (#100325)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 25 01:44:37 PDT 2024
Author: Andrzej Warzyński
Date: 2024-07-25T09:44:34+01:00
New Revision: c7a3346ab6a8fbd551a80bd4028ec8624daa35e4
URL: https://github.com/llvm/llvm-project/commit/c7a3346ab6a8fbd551a80bd4028ec8624daa35e4
DIFF: https://github.com/llvm/llvm-project/commit/c7a3346ab6a8fbd551a80bd4028ec8624daa35e4.diff
LOG: [mlir][linalg] Fix scalable vectorisation of tensor.extract (#100325)
This PR fixes one very specific aspect of vectorising `tensor.extract`
Ops when targeting scalable vectors. Namely, it makes sure that the
scalable flag is correctly propagated when creating
`vector::ShapeCastOp`.
BEFORE:
```mlir
vector.shape_cast %idx_vec : vector<1x1x[4]xindex> to vector<4xindex>
```
AFTER:
```mlir
vector.shape_cast %idx_vec : vector<1x1x[4]xindex> to vector<[4]xindex>
```
This particular ShapeCastOp is created when generating an index for
`vector.transfer_read` operations. Strictly speaking, casting is not
really required. However, it makes the subsequent address calculation
much simpler (*).
The following test is updated to demonstrate the use of
`vector.shape_cast` by the vectoriser:
* @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous
Similar test with scalable vectors is also added.
(*) At this point in the vectoriser it is known
that all leading dims in the index vector are "1".
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c4dab7d061b4b..9185663799e52 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1077,19 +1077,20 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
// (0th) element and use that.
SmallVector<Value> transferReadIdxs;
- auto resTrailingDim = resultType.getShape().back();
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
- auto idx = bvm.lookup(extractOp.getIndices()[i]);
+ Value idx = bvm.lookup(extractOp.getIndices()[i]);
if (idx.getType().isIndex()) {
transferReadIdxs.push_back(idx);
continue;
}
auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
- loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
- bvm.lookup(extractOp.getIndices()[i]));
+ loc,
+ VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
+ resultType.getScalableDims().back()),
+ idx);
transferReadIdxs.push_back(
rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
}
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index f042753780013..964565620fd01 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -1,29 +1,52 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
-func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
+func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+ %src: tensor<80x16xf32>,
+ %output : tensor<1x3xf32>,
+ %idx: index) -> tensor<1x3xf32> {
+
%c79 = arith.constant 79 : index
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
- } outs(%extracted_slice : tensor<1x3xf32>) {
+ } outs(%output : tensor<1x3xf32>) {
^bb0(%out: f32):
%2 = linalg.index 1 : index
- %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
- %extracted = tensor.extract %6[%c79, %3] : tensor<80x16xf32>
+ %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+ %extracted = tensor.extract %src[%c79, %3] : tensor<80x16xf32>
linalg.yield %extracted : f32
} -> tensor<1x3xf32>
return %1 : tensor<1x3xf32>
}
// CHECK-LABEL: func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 3 : index
-// CHECK: %[[VAL_8:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_5]] : vector<1x4xi1>
-// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<1x3xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_11:.*]] = vector.broadcast {{.*}} : index to vector<4xindex>
-// CHECK: %[[VAL_12:.*]] = arith.addi {{.*}} : vector<4xindex>
-// CHECK: %[[VAL_20:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_22:.*]] = vector.mask %[[VAL_8]] { vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x3xf32> } : vector<1x4xi1> -> tensor<1x3xf32>
+// CHECK-SAME: %[[SRC:.*]]: tensor<80x16xf32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x3xf32>,
+// CHECK-SAME: %[[IDX_IN:.*]]: index) -> tensor<1x3xf32> {
+
+/// Create the mask
+// CHECK-DAG: %[[DIM_0:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM_1:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<1x3xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<4xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+
+// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{\[}}%[[C0_1]], %[[C0_1]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x3xf32> } : vector<1x4xi1> -> tensor<1x3xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -33,7 +56,69 @@ module attributes {transform.with_named_sequence} {
}
}
- // -----
+// -----
+
+// Identical to the above, but with scalable vectors.
+
+func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+ %src: tensor<80x16xf32>,
+ %output : tensor<1x3xf32>,
+ %idx: index) -> tensor<1x3xf32> {
+
+ %c79 = arith.constant 79 : index
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } outs(%output : tensor<1x3xf32>) {
+ ^bb0(%out: f32):
+ %2 = linalg.index 1 : index
+ %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+ %extracted = tensor.extract %src[%c79, %3] : tensor<80x16xf32>
+ linalg.yield %extracted : f32
+ } -> tensor<1x3xf32>
+
+ return %1 : tensor<1x3xf32>
+}
+
+// CHECK-LABEL: func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable
+// CHECK-SAME: %[[SRC:.*]]: tensor<80x16xf32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x3xf32>,
+// CHECK-SAME: %[[IDX_IN:.*]]: index) -> tensor<1x3xf32> {
+
+/// Create the mask
+// CHECK-DAG: %[[DIM_0:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM_1:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<1x3xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<[4]xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+
+// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{\[}}%[[C0_1]], %[[C0_1]]] {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<1x3xf32> } : vector<1x[4]xi1> -> tensor<1x3xf32>
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c79 = arith.constant 79 : index
More information about the Mlir-commits
mailing list