[Mlir-commits] [mlir] [mlir][linalg] Relax scalable vectorization restrictions (PR #117991)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Nov 28 02:52:47 PST 2024


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/117991

Currently, the Linalg vectorizer does not allow non-trailing parallel
dimensions to be scalable, e.g., `vector_sizes [[8], 1]` (*), for cases
like:

```mlir
%0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
```

This restriction exists to avoid generating "scalable" arrays of
aggregates, which LLVM does not support (multi-dim vectors are lowered
into arrays of aggregates at the LLVM level).

This patch relaxes that restriction when the trailing parallel vector
dimension is `1`, e.g., for `vector_sizes [[8], 1]`. Such cases are safe
since trailing unit dimensions can be collapsed. This relaxation is
necessary to support scalable vectorization for tensor.pack, where inner
tile sizes are `[8]` (scalable) and `1` (scalar).

(*) Transform Dialect notation


>From 687cdfd84fc4d0494e23a262f6443352c18d910a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 28 Nov 2024 10:39:25 +0000
Subject: [PATCH] [mlir][linalg] Relax scalable vectorization restrictions

Currently, the Linalg vectorizer does not allow non-trailing parallel
dimensions to be scalable, e.g., `vector_sizes [[8], 1]` (*), for cases
like:

```mlir
%0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
```

This restriction exists to avoid generating "scalable" arrays of
aggregates, which LLVM does not support (multi-dim vectors are lowered
into arrays of aggregates at the LLVM level).

This patch relaxes that restriction when the trailing parallel vector
dimension is `1`, e.g., for `vector_sizes [[8], 1]`. Such cases are safe
since trailing unit dimensions can be collapsed. This relaxation is
necessary to support scalable vectorization for tensor.pack, where inner
tile sizes are `[8]` (scalable) and `1` (scalar).

(*) Transform Dialect notation
---
 .../Linalg/Transforms/Vectorization.cpp       | 27 ++++++++++++-------
 .../Linalg/vectorization-scalable.mlir        | 17 +++++++-----
 2 files changed, 29 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 06bb6c0fb1cac9..f3fffbef67dc71 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2022,26 +2022,35 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
   // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
   // it matches one of the supported cases:
-  //  1. exactly 1 dim is scalable and that's the _last_ parallel dim
-  //  2. exactly 2 dims are scalable and those are the _last two adjacent_
-  //     parallel dims
-  //  3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
+  //  1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
+  //    (*).
+  //  2. Exactly 2 dims are scalable and those are the _last two adjacent_
+  //     parallel dims.
+  //  3. Exactly 1 reduction dim is scalable and that's the last (innermost) dim.
   // The 2nd restriction above means that only Matmul-like Ops are supported
   // when 2 dims are scalable, e.g. :
   //    * iterators = [parallel, parallel, reduction]
   //    * scalable flags = [true, true, false]
+  //
+  // (*) Unit dims get folded away in practice.
+  // TODO: Relax these conditions as good motivating examples are identified.
 
-  // Find the first scalable flag
-  bool seenParalell = false;
+  // Find the first scalable flag, and ...
+  bool seenNonUnitParallel = false;
   auto iterators = linalgOp.getIteratorTypesArray();
   SmallVector<bool> scalableFlags(inputScalableVecDims);
-  while (!scalableFlags.back()) {
-    seenParalell |= (iterators.back() == utils::IteratorType::parallel);
+  int64_t idx = scalableFlags.size() - 1;
+  while (!scalableFlags[idx]) {
+    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
+    seenNonUnitParallel |=
+        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
 
     iterators.pop_back();
     scalableFlags.pop_back();
+    idx--;
   }
 
+  // ... analyze the corresponding iterator.
   switch (iterators.back()) {
   case utils::IteratorType::reduction: {
     // Check 3. above is met.
@@ -2059,7 +2068,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
   }
   case utils::IteratorType::parallel: {
     // Check 1. and 2. above are met.
-    if (seenParalell) {
+    if (seenNonUnitParallel) {
       LDBG("Inner parallel dim not requested for scalable "
            "vectorization\n");
       return failure();
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index 68bac72a1465d0..227829238a3d79 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -122,22 +122,27 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @vectorize_dynamic_fill(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
+// NOTE: Non-trailing scalable sizes are often problematic - there are no
+// "scalable" arrays of vectors at the LLVM level (multi-dim vectors are
+// decomposed into arrays of aggregates). However, the trailing dim in this
+// case is 1, so it can be folded away later.
+
+func.func @vectorize_dynamic_fill_leading_scalable(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
   %0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_dynamic_fill
+// CHECK-LABEL: func.func @vectorize_dynamic_fill_leading_scalable
 //   CHECK: %[[DIM0:.*]] = tensor.dim
 //   CHECK: %[[DIM1:.*]] = tensor.dim
-//   CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<8x[16]xi1>
-//   CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<8x[16]xf32>
-//   CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<8x[16]xf32>, tensor<?x?xf32> } : vector<8x[16]xi1>
+//   CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[8]x1xi1>
+//   CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<[8]x1xf32>
+//   CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<[8]x1xf32>, tensor<?x?xf32> } : vector<[8]x1xi1>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [8, [16]] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [[8], 1] : !transform.any_op
     transform.yield
   }
 }



More information about the Mlir-commits mailing list