[Mlir-commits] [mlir] Reapply "[mlir][linalg] Relax tensor.extract vectorization" (#102232) (PR #102321)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 7 08:02:12 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

[This reverts commit 6662523d6b2ca0198141c94ee80ebbb41601df9f]

Simplifies the vectorization of tensor.extract so that:
* all cases that read into a genuinely multi-dim vector (*) are
  considered a gather load,
* all other cases are considered as potential contiguous loads.

This change means that the following extraction from a "column" tensor
is correctly identified as a scalar load followed by a broadcast (rather
than a gather load).

```mlir
func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
  %c4 = arith.constant 4 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<[...]> : tensor<15x1xi32>

  %out = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    outs(%in : tensor<1x1x4xi32>) {

  ^bb0(%out: i32):
    %8 = linalg.index 0 : index
    %idx_0 = linalg.index 0 : index
    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
    linalg.yield %extracted : i32
  } -> tensor<1x1x4xi32>

  return %out:tensor<1x1x4xi32>
}
```

Overview of the delta when compared to the original submission:
  * removed an assert representing a condition that is being relaxed
    here,
  * added a test (reading from a column tensor) based on a repro from
    @<!-- -->hanhanW.

(*) `vector<1x4x1xf32>` is considered a 1D vector in this context.


---
Full diff: https://github.com/llvm/llvm-project/pull/102321.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+16-25) 
- (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+105-12) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3d0d6abf702d7..63dcda78d0f2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -814,11 +814,9 @@ enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
   auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
-                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
+  assert(llvm::count_if(targetShape,
+                        [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
-         "1-D vectors with the trailing dim eqaual 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
   // However, analysing block arguments for _this_ linalg.generic Op is a bit
@@ -879,8 +877,6 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   assert(((llvm::count_if(targetShape,
                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
-         "1-D vectors with the trailing dim 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
   // However, analysing block arguments for _this_ linalg.generic Op is a bit
@@ -946,27 +942,22 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   if (linalgOp.hasDynamicShape())
     return VectorMemoryAccessKind::Gather;
 
-  // 1. Assume that it's a gather load when reading _into_:
-  //    * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
-  //    * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
-  // TODO: Relax these conditions.
-  // FIXME: This condition assumes non-dynamic sizes.
-  if ((llvm::count_if(targetShape,
-                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
-      targetShape.back() == 1)
-    return VectorMemoryAccessKind::Gather;
+  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
+  // otherwise.
+  bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
+                             return dimSize > 1;
+                           }) == 1);
 
-  // 2. Assume that it's a gather load when reading _from_ a tensor for which
-  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
-  // TODO: Relax this condition.
-  if (inputShape.getShape().back() == 1)
+  // 1. Assume that it's a gather load when reading non-1D vector.
+  if (!isOutput1DVector)
     return VectorMemoryAccessKind::Gather;
 
   bool leadingIdxsLoopInvariant = true;
 
-  // 3. Analyze the leading indices of `extractOp`.
+  // 2. Analyze the leading indices of `extractOp`.
   // Look at the way each index is calculated and decide whether it is suitable
-  // for a contiguous load, i.e. whether it's loop invariant.
+  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
+  // gather load.
   auto indices = extractOp.getIndices();
   auto leadIndices = indices.drop_back(1);
 
@@ -982,13 +973,13 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::Gather;
   }
 
-  // 4. Analyze the trailing index for `extractOp`.
+  // 3. Analyze the trailing index for `extractOp`.
   // At this point we know that the leading indices are loop invariant. This
   // means that is potentially a scalar or a contiguous load. We can decide
   // based on the trailing idx.
   auto extractOpTrailingIdx = indices.back();
 
-  // 4a. Scalar broadcast load
+  // 3a. Scalar broadcast load
   // If the trailing index is loop invariant then this is a scalar load.
   if (leadingIdxsLoopInvariant &&
       isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
@@ -997,7 +988,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::ScalarBroadcast;
   }
 
-  // 4b. Contiguous loads
+  // 3b. Contiguous loads
   // The trailing `extractOp` index should increment with every loop iteration.
   // This effectively means that it must be based on the trailing loop index.
   // This is what the following bool captures.
@@ -1011,7 +1002,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::Contiguous;
   }
 
-  // 5. Fallback case - gather load.
+  // 4. Fallback case - gather load.
   LDBG("Found gather load: " << extractOp);
   return VectorMemoryAccessKind::Gather;
 }
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 85e1c56dd45a0..bdaa20c3bf971 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -37,6 +37,7 @@ module attributes {transform.with_named_sequence} {
 }
 
 // -----
+
 #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %c0 = arith.constant 1 : index
@@ -74,20 +75,24 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
-  %1 = linalg.generic {
-    indexing_maps = [#map1],
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_nd_tensor_extract_transfer_read_basic(
+    %arg0: tensor<3x3x3xf32>,
+    %arg1: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+
+  %res = linalg.generic {
+    indexing_maps = [#map],
     iterator_types = ["parallel", "parallel", "parallel"]
-  } outs(%arg2 : tensor<1x1x3xf32>) {
-  ^bb0(%arg4: f32):
-    %2 = linalg.index 0 : index
-    %3 = linalg.index 1 : index
-    %4 = linalg.index 2 : index
-    %5 = tensor.extract %arg0[%2, %3, %4] : tensor<3x3x3xf32>
-    linalg.yield %5 : f32
+  } outs(%arg1 : tensor<1x1x3xf32>) {
+  ^bb0(%out: f32):
+    %1 = linalg.index 0 : index
+    %2 = linalg.index 1 : index
+    %3 = linalg.index 2 : index
+    %4 = tensor.extract %arg0[%1, %2, %3] : tensor<3x3x3xf32>
+    linalg.yield %4 : f32
   } -> tensor<1x1x3xf32>
-  return %1 : tensor<1x1x3xf32>
+
+  return %res : tensor<1x1x3xf32>
 }
 
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
@@ -104,6 +109,38 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf
 // CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
 // CHECK:   vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
+// Same as example above, but reading into a column tensor. Note that after the
+// vectorizatoin, the `TransferOpReduceRank` will replace
+// `vector.transfer_read` with `tensor.extract -> scalar`.
+
+// TODO: Currently this fails to vectorise when the indices are non-constant.
+
+func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
+    %input: tensor<3x3x3xf32>,
+    %output: tensor<3x1x1xf32>) -> tensor<3x1x1xf32> {
+
+  %c0 = arith.constant 0 : index
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%output : tensor<3x1x1xf32>) {
+  ^bb0(%out: f32):
+    %5 = tensor.extract %input[%c0, %c0, %c0] : tensor<3x3x3xf32>
+    linalg.yield %5 : f32
+  } -> tensor<3x1x1xf32>
+
+  return %res : tensor<3x1x1xf32>
+}
+
+// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<3x3x3xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<3x1x1xf32>)
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : tensor<3x3x3xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<3x1x1xf32>
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[BCAST]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x1x1xf32>, tensor<3x1x1xf32>
+// CHECK:           return %[[RES]] : tensor<3x1x1xf32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -595,3 +632,59 @@ module attributes {transform.with_named_sequence} {
      transform.yield
    }
 }
+
+
+// -----
+
+func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+
+  %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
+  ^bb0(%out: i32):
+    %8 = linalg.index 0 : index
+    %idx_0 = linalg.index 0 : index
+    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
+    linalg.yield %extracted : i32
+  } -> tensor<1x1x4xi32>
+
+  return %out:tensor<1x1x4xi32>
+}
+
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK-LABEL:   func.func @vectorize_scalar_broadcast_column_tensor(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
+// CHECK:           %[[VAL_10:.*]] = vector.step : vector<1xindex>
+// CHECK:           %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK:           %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK:           %[[VAL_13:.*]] = vector.step : vector<1xindex>
+// CHECK:           %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK:           %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK:           %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
+// CHECK:           %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
+// CHECK:           %[[VAL_18:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_19:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]]{{\[}}%[[VAL_19]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
+// CHECK:           %[[VAL_24:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0  vector_sizes [1, 1, 4]{ vectorize_nd_extract } : !transform.any_op
+    transform.yield
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/102321


More information about the Mlir-commits mailing list