[Mlir-commits] [mlir] 5f6c036 - [mlir][linalg] Extend scalable vectorisation support

Andrzej Warzynski llvmlistbot at llvm.org
Tue Aug 22 02:42:06 PDT 2023


Author: Andrzej Warzynski
Date: 2023-08-22T09:40:34Z
New Revision: 5f6c03636d22029550c8ab94388827e728a0c5fe

URL: https://github.com/llvm/llvm-project/commit/5f6c03636d22029550c8ab94388827e728a0c5fe
DIFF: https://github.com/llvm/llvm-project/commit/5f6c03636d22029550c8ab94388827e728a0c5fe.diff

LOG: [mlir][linalg] Extend scalable vectorisation support

This patch simply removes one of the pre-conditions of scalable
vectorisation, namely that only the trailing vector dimension can be
scalable, e.g.

  vector<2x4x[8]xi32>

This limitation can be lifted following the recent work to support
scalable vectorisation in Linalg, most notably:
  * https://reviews.llvm.org/D154336
  * https://reviews.llvm.org/D153372

A test is added that demonstrates scalable vectorisation of
`linalg.matmul`. This is simply a copy of a similar test for fixed-width
vectorisation. The latter is updated - check lines are simplified and
annotated with regex patterns. This is in the spirit of [1]:

  Tests should be minimal, and only check what is absolutely necessary.

Note that this change is just one step towards scalable vectorisation in
Linalg. It will allow us to exercise new code paths in the context of
scalable vectors in Linalg and hence make further progress in the
forthcoming patches.

[1] https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices

Differential Revision: https://reviews.llvm.org/D158423

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization-masked.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 84eea84bd10046..fdb697cf1167af 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -169,21 +169,6 @@ static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
   return res;
 }
 
-/// Return true if the scalable vector dimensions are supported. For now, we
-/// only support scalable vectors in the trailing dimension.
-static bool areValidScalableVecDims(ArrayRef<bool> scalableVecDims) {
-  if (scalableVecDims.empty())
-    return true;
-
-  auto isScalable = [](bool isScalableVecSize) { return isScalableVecSize; };
-  if (std::any_of(scalableVecDims.begin(), scalableVecDims.end() - 1,
-                  isScalable)) {
-    return false;
-  }
-
-  return true;
-}
-
 /// Contains the vectorization state and related methods used across the
 /// vectorization process of a given operation.
 struct VectorizationState {
@@ -217,12 +202,6 @@ struct VectorizationState {
       scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
     }
 
-    // Make sure we don't end up with unsupported scalable vector dimensions
-    // after the permutation. If so, we should bail out on that operation in the
-    // scalable preconditions.
-    assert(areValidScalableVecDims(scalableDims) &&
-           "Permuted scalable vector dimensions are not supported");
-
     return VectorType::get(vectorShape, elementType, scalableDims);
   }
 
@@ -1630,11 +1609,6 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (inputVectorSizes.empty())
     return success();
 
-  if (!areValidScalableVecDims(inputScalableVecDims)) {
-    LDBG("Non-trailing scalable vector dimensions are not supported\n");
-    return failure();
-  }
-
   bool isScalable = inputScalableVecDims.back();
   if (!isScalable)
     return success();

diff  --git a/mlir/test/Dialect/Linalg/vectorization-masked.mlir b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
index fc7749aaa36482..82e8dfe37f7999 100644
--- a/mlir/test/Dialect/Linalg/vectorization-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
@@ -447,41 +447,68 @@ transform.sequence failures(propagate) {
 
 // -----
 
-func.func @vectorize_dynamic_matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
             outs(%C: memref<?x?xf32>)
   return
 }
 
-// CHECK-LABEL:   func.func @vectorize_dynamic_matmul(
-// CHECK-SAME:      %[[VAL_0:.*]]: memref<?x?xf32>, %[[VAL_1:.*]]: memref<?x?xf32>, %[[VAL_2:.*]]: memref<?x?xf32>) {
+// CHECK-LABEL:   func.func @matmul(
+// CHECK-SAME:      %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>, %[[C:.*]]: memref<?x?xf32>) {
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?x?xf32>
+// CHECK-DAG:       %[[VAL_4:.*]] = memref.dim %[[A]], %[[VAL_3]] : memref<?x?xf32>
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_6:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?x?xf32>
+// CHECK-DAG:       %[[VAL_6:.*]] = memref.dim %[[B]], %[[VAL_5]] : memref<?x?xf32>
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_8:.*]] = memref.dim %[[VAL_0]], %[[VAL_7]] : memref<?x?xf32>
-// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_11:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
-// CHECK:           %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_10]] {in_bounds = [true, true, true], permutation_map = #map} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<8x4xi1> -> vector<8x16x4xf32>
-// CHECK:           %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_14:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x16xi1>
-// CHECK:           %[[VAL_15:.*]] = vector.mask %[[VAL_14]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_13]] {in_bounds = [true, true, true], permutation_map = #map1} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<4x16xi1> -> vector<8x16x4xf32>
-// CHECK:           %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_17:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x16xi1>
-// CHECK:           %[[VAL_18:.*]] = vector.mask %[[VAL_17]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_16]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32> } : vector<8x16xi1> -> vector<8x16xf32>
-// CHECK:           %[[VAL_19:.*]] = arith.mulf %[[VAL_12]], %[[VAL_15]] : vector<8x16x4xf32>
-// CHECK:           %[[VAL_20:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x16x4xi1>
-// CHECK:           %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.multi_reduction <add>, %[[VAL_19]], %[[VAL_18]] [2] : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32>
-// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : index
-// CHECK:           vector.mask %[[VAL_17]] { vector.transfer_write %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_22]], %[[VAL_22]]] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?xf32> } : vector<8x16xi1>
-// CHECK:           return
-// CHECK:         }
+// CHECK-DAG:       %[[VAL_8:.*]] = memref.dim %[[A]], %[[VAL_7]] : memref<?x?xf32>
+// CHECK:           %[[MASK_A:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
+// CHECK:           %[[LOAD_A:.*]] = vector.mask %[[MASK_A]] { vector.transfer_read %[[A]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<8x4xi1> -> vector<8x16x4xf32>
+// CHECK:           %[[MASK_B:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x16xi1>
+// CHECK:           %[[LOAD_B:.*]] = vector.mask %[[MASK_B]] { vector.transfer_read %[[B]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<4x16xi1> -> vector<8x16x4xf32>
+// CHECK:           %[[MASK_C:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x16xi1>
+// CHECK:           %[[LOAD_C:.*]] = vector.mask %[[MASK_C]] { vector.transfer_read %[[C]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32> } : vector<8x16xi1> -> vector<8x16xf32>
+// CHECK:           %[[MULF:.*]] = arith.mulf %[[LOAD_A]], %[[LOAD_B]] : vector<8x16x4xf32>
+// CHECK:           %[[MASK_MULIT_RED:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x16x4xi1>
+// CHECK:           %[[MULTI_RED:.*]] = vector.mask %[[MASK_MULIT_RED]] { vector.multi_reduction <add>, %[[MULF]], %[[LOAD_C]] [2] : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32>
+// CHECK:           %[[C2:.*]] = arith.constant 0 : index
+// CHECK:           vector.mask %[[MASK_C]] { vector.transfer_write %[[MULTI_RED]], %[[C]]{{\[}}%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?xf32> } : vector<8x16xi1>
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
-  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4] : !transform.any_op
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.masked_vectorize %matmul vector_sizes [8, 16, 4] : !transform.any_op
+}
+
+// -----
+
+func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+            outs(%C: memref<?x?xf32>)
+  return
 }
 
+// CHECK-LABEL:   func.func @matmul_scalable(
+// CHECK-SAME:      %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>, %[[C:.*]]: memref<?x?xf32>) {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = memref.dim %[[A]], %[[VAL_3]] : memref<?x?xf32>
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = memref.dim %[[B]], %[[VAL_5]] : memref<?x?xf32>
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_8:.*]] = memref.dim %[[A]], %[[VAL_7]] : memref<?x?xf32>
+// CHECK:           %[[MASK_A:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
+// CHECK:           %[[LOAD_A:.*]] = vector.mask %[[MASK_A]] { vector.transfer_read %[[A]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x[16]x4xf32> } : vector<8x4xi1> -> vector<8x[16]x4xf32>
+// CHECK:           %[[MASK_B:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x[16]xi1>
+// CHECK:           %[[LOAD_B:.*]] = vector.mask %[[MASK_B]] { vector.transfer_read %[[B]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x[16]x4xf32> } : vector<4x[16]xi1> -> vector<8x[16]x4xf32>
+// CHECK:           %[[MASK_C:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x[16]xi1>
+// CHECK:           %[[LOAD_C:.*]] = vector.mask %[[MASK_C]] { vector.transfer_read %[[C]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x[16]xf32> } : vector<8x[16]xi1> -> vector<8x[16]xf32>
+// CHECK:           %[[MULF:.*]] = arith.mulf %[[LOAD_A]], %[[LOAD_B]] : vector<8x[16]x4xf32>
+// CHECK:           %[[MASK_MULIT_RED:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x[16]x4xi1>
+// CHECK:           %[[MULTI_RED:.*]] = vector.mask %[[MASK_MULIT_RED]] { vector.multi_reduction <add>, %[[MULF]], %[[LOAD_C]] [2] : vector<8x[16]x4xf32> to vector<8x[16]xf32> } : vector<8x[16]x4xi1> -> vector<8x[16]xf32>
+// CHECK:           %[[C2:.*]] = arith.constant 0 : index
+// CHECK:           vector.mask %[[MASK_C]] { vector.transfer_write %[[MULTI_RED]], %[[C]]{{\[}}%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<8x[16]xf32>, memref<?x?xf32> } : vector<8x[16]xi1>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.masked_vectorize %matmul vector_sizes [8, [16], 4] : !transform.any_op
+}


        


More information about the Mlir-commits mailing list