[Mlir-commits] [mlir] c7a3346 - [mlir][linalg] Fix scalable vectorisation of tensor.extract (#100325)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 25 01:44:37 PDT 2024

Author: Andrzej Warzyński
Date: 2024-07-25T09:44:34+01:00
New Revision: c7a3346ab6a8fbd551a80bd4028ec8624daa35e4

URL: https://github.com/llvm/llvm-project/commit/c7a3346ab6a8fbd551a80bd4028ec8624daa35e4
DIFF: https://github.com/llvm/llvm-project/commit/c7a3346ab6a8fbd551a80bd4028ec8624daa35e4.diff

LOG: [mlir][linalg] Fix scalable vectorisation of tensor.extract (#100325)

This PR fixes one very specific aspect of vectorising `tensor.extract`
Ops when targeting scalable vectors. Namely, it makes sure that the
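scalable flag is correctly propagated when creating `vector.shape_cast`.

Before (the scalable flag is dropped from the result type):

vector.shape_cast %idx_vec : vector<1x1x[4]xindex> to vector<4xindex>

After (the scalable flag is preserved in the result type):

vector.shape_cast %idx_vec : vector<1x1x[4]xindex> to vector<[4]xindex>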

This particular ShapeCastOp is created when generating an index for
`vector.transfer_read` operations. Strictly speaking, casting is not
really required. However, it makes the subsequent address calculation
much simpler (*).

The following test is updated to demonstrate the use of
`vector.shape_cast` by the vectoriser:

  * @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous

A similar test with scalable vectors is also added.

(*) At this point in the vectoriser it is known
that all leading dims in the index vector are "1".




diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c4dab7d061b4b..9185663799e52 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1077,19 +1077,20 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
   //    (0th) element and use that.
   SmallVector<Value> transferReadIdxs;
-  auto resTrailingDim = resultType.getShape().back();
   auto zero = rewriter.create<arith::ConstantOp>(
       loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
-    auto idx = bvm.lookup(extractOp.getIndices()[i]);
+    Value idx = bvm.lookup(extractOp.getIndices()[i]);
     if (idx.getType().isIndex()) {
     auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
-        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
-        bvm.lookup(extractOp.getIndices()[i]));
+        loc,
+        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
+                        resultType.getScalableDims().back()),
+        idx);
         rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));

diff  --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index f042753780013..964565620fd01 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -1,29 +1,52 @@
 // RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
-func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
+func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+    %src: tensor<80x16xf32>,
+    %output : tensor<1x3xf32>,
+    %idx: index) -> tensor<1x3xf32> {
   %c79 = arith.constant 79 : index
   %1 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]
-  } outs(%extracted_slice : tensor<1x3xf32>) {
+  } outs(%output : tensor<1x3xf32>) {
   ^bb0(%out: f32):
     %2 = linalg.index 1 : index
-    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
-    %extracted = tensor.extract %6[%c79, %3] : tensor<80x16xf32>
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<80x16xf32>
     linalg.yield %extracted : f32
   } -> tensor<1x3xf32>
   return %1 : tensor<1x3xf32>
 // CHECK-LABEL:   func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 3 : index
-// CHECK:           %[[VAL_8:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_5]] : vector<1x4xi1>
-// CHECK:           %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<1x3xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK:           %[[VAL_11:.*]] = vector.broadcast {{.*}} : index to vector<4xindex>
-// CHECK:           %[[VAL_12:.*]] = arith.addi {{.*}} : vector<4xindex>
-// CHECK:           %[[VAL_20:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK:           %[[VAL_22:.*]] = vector.mask %[[VAL_8]] { vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x3xf32> } : vector<1x4xi1> -> tensor<1x3xf32>
+// CHECK-SAME:      %[[SRC:.*]]: tensor<80x16xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x3xf32>,
+// CHECK-SAME:      %[[IDX_IN:.*]]: index) -> tensor<1x3xf32> {
+/// Create the mask
+// CHECK-DAG:       %[[DIM_0:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[DIM_1:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C79:.*]] = arith.constant 79 : index
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+/// TODO: This transfer_read is redundant - remove
+// CHECK:           vector.mask %[[MASK]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<1x3xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+/// Calculate the index vector
+// CHECK:           %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<4xindex>
+// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+/// Extract the starting point from the index vector
+// CHECK:           %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+// Final read and write
+// CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK:           %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK:           vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{\[}}%[[C0_1]], %[[C0_1]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x3xf32> } : vector<1x4xi1> -> tensor<1x3xf32>
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -33,7 +56,69 @@ module attributes {transform.with_named_sequence} {
- // -----
+// -----
+// Identical to the above, but with scalable vectors.
+func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+    %src: tensor<80x16xf32>,
+    %output : tensor<1x3xf32>,
+    %idx: index) -> tensor<1x3xf32> {
+  %c79 = arith.constant 79 : index
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%output : tensor<1x3xf32>) {
+  ^bb0(%out: f32):
+    %2 = linalg.index 1 : index
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<80x16xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<1x3xf32>
+  return %1 : tensor<1x3xf32>
+// CHECK-LABEL:   func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable
+// CHECK-SAME:      %[[SRC:.*]]: tensor<80x16xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x3xf32>,
+// CHECK-SAME:      %[[IDX_IN:.*]]: index) -> tensor<1x3xf32> {
+/// Create the mask
+// CHECK-DAG:       %[[DIM_0:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[DIM_1:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C79:.*]] = arith.constant 79 : index
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+/// TODO: This transfer_read is redundant - remove
+// CHECK:           vector.mask %[[MASK]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<1x3xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+/// Calculate the index vector
+// CHECK:           %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<[4]xindex>
+// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+/// Extract the starting point from the index vector
+// CHECK:           %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+// Final read and write
+// CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK:           %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK:           vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{\[}}%[[C0_1]], %[[C0_1]]] {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<1x3xf32> } : vector<1x[4]xi1> -> tensor<1x3xf32>
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+     transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+     transform.yield
+   }
+// -----
 func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c79 = arith.constant 79 : index


