[Mlir-commits] [mlir] [mlir][linalg] Restrict scalable vectorisation (PR #98639)

Zhaoshi Zheng llvmlistbot at llvm.org
Fri Jul 12 21:52:58 PDT 2024


================
@@ -1936,26 +1936,79 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
   return success();
 }
 
-/// Preconditions for scalable vectors.
+/// Preconditions for scalable vectors. This is quite restrictive - it models
+/// the fact that in practice we would only make selected dimensions scalable.
 static LogicalResult
 vectorizeScalableVectorPrecondition(Operation *op,
                                     ArrayRef<int64_t> inputVectorSizes,
                                     ArrayRef<bool> inputScalableVecDims) {
   assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
          "Number of input vector sizes and scalable dims doesn't match");
 
-  if (inputVectorSizes.empty())
-    return success();
+  size_t numOfScalableDims =
+      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
 
-  bool isScalable = inputScalableVecDims.back();
-  if (!isScalable)
+  if (numOfScalableDims == 0)
     return success();
 
-  // Only element-wise and 1d depthwise conv ops supported in the presence of
-  // scalable dims.
   auto linalgOp = dyn_cast<LinalgOp>(op);
-  return success(linalgOp && (isElementwise(linalgOp) ||
-                              isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
+
+  // Cond 1: There's been no need for scalable vectorisation of
+  // non-linalg Ops so far
+  if (!linalgOp)
+    return failure();
+
+  // Cond 2: There's been no need for more than 2 scalable dims so far
+  if (numOfScalableDims > 2)
+    return failure();
+
+  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
+  // it matches one of the supported cases:
+  //  1. exactly 1 dim is scalable and that's the _last_ parallel dim
----------------
zhaoshiz wrote:

I'm reworking scalable vectorization of reductions (https://github.com/llvm/llvm-project/pull/97788) on top of this one. My goal is to allow linalg::ReduceOp and linalg::GenericOp with reduction iterators. I am testing with matvec and matmul. For now, I'm restricting the reduction to the last dim.

> It should be ok as long as we have a single scalable dimension, isn't it?

At the MLIR level it seems OK: both linalg vectorization and the lowering of vector multi-dim reductions produce reasonable results. However, I'm having difficulty lowering to the LLVM dialect and LLVM IR, perhaps due to the following:

> it would be impractical given the limitations of LLVM (which usually
> reflect the limitations of actual hardware) - e.g. no support for
> "scalable" arrays of scalable or fixed width vectors (\*).
> ...
> (\*) At MLIR vector level that would correspond to e.g.
> vector<[4]x8xf32>. 

Here's an example:

```
func.func @linalg_reduce_scalable_leading_dim(%input: tensor<?x?xf32>,
                                              %acc: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.reduce ins(%input : tensor<?x?xf32>) outs(%acc : tensor<?xf32>) dimensions = [0]
  (%in: f32, %init: f32) {
    %0 = arith.addf %in, %init : f32
    linalg.yield %0 : f32
  }
  return %0 : tensor<?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [[4], 1] : !transform.any_op

    %func = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op

    transform.apply_patterns to %func {
      transform.apply_patterns.vector.lower_masked_transfers
      transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
    } : !transform.any_op

    transform.yield
  }
}
```
After linalg-vectorization:
```
module {
  func.func @linalg_reduce_scalable_leading_dim(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    %c1 = arith.constant 1 : index
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
    %c0_1 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = vector.create_mask %dim, %dim_0 : vector<[4]x1xi1>
    %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1, %c0_1], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[4]x1xf32> } : vector<[4]x1xi1> -> vector<[4]x1xf32>
    %cst_2 = arith.constant 0.000000e+00 : f32
    %2 = vector.create_mask %dim_0 : vector<1xi1>
    %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_1], %cst_2 {in_bounds = [true]} : tensor<?xf32>, vector<1xf32> } : vector<1xi1> -> vector<1xf32>
    %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [0] : vector<[4]x1xf32> to vector<1xf32> } : vector<[4]x1xi1> -> vector<1xf32>
    %c0_3 = arith.constant 0 : index
    %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_3] {in_bounds = [true]} : vector<1xf32>, tensor<?xf32> } : vector<1xi1> -> tensor<?xf32>
    return %5 : tensor<?xf32>
  }
  module attributes {transform.with_named_sequence} {
  }
}
```
After lowering the masked vector transfers and the multi-reduction:
```
module {
  func.func @linalg_reduce_scalable_leading_dim(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant dense<0.000000e+00> : vector<1xf32>
    %cst_0 = arith.constant 0.000000e+00 : f32
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    %dim_1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
    %0 = vector.create_mask %dim, %dim_1 : vector<[4]x1xi1>
    %1 = vector.transfer_read %arg0[%c0, %c0], %cst_0, %0 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[4]x1xf32>
    %2 = vector.create_mask %dim_1 : vector<1xi1>
    %3 = vector.transfer_read %arg1[%c0], %cst_0, %2 {in_bounds = [true]} : tensor<?xf32>, vector<1xf32>
    %4 = vector.transpose %0, [1, 0] : vector<[4]x1xi1> to vector<1x[4]xi1>
    %5 = vector.transpose %1, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
    %6 = vector.extract %5[0] : vector<[4]xf32> from vector<1x[4]xf32>
    %7 = vector.extract %3[0] : f32 from vector<1xf32>
    %8 = vector.extract %4[0] : vector<[4]xi1> from vector<1x[4]xi1>
    %9 = vector.mask %8 { vector.reduction <add>, %6, %7 : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
    %10 = vector.insertelement %9, %cst[%c0 : index] : vector<1xf32>
    %11 = vector.transfer_write %10, %arg1[%c0], %2 {in_bounds = [true]} : vector<1xf32>, tensor<?xf32>
    return %11 : tensor<?xf32>
  }
  module attributes {transform.with_named_sequence} {
  }
}
```
Trying to lower the above MLIR to LLVM with `mlir-opt -test-lower-to-llvm`:
```
module {
  func.func @linalg_reduce_scalable_leading_dim(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    %0 = llvm.mlir.constant(4 : i32) : i32
    %1 = llvm.mlir.constant(0 : i64) : i64
    %2 = llvm.mlir.undef : vector<[4]xi32>
    %3 = llvm.mlir.constant(0 : i32) : i32
    %4 = llvm.mlir.undef : vector<1xi32>
    %5 = llvm.mlir.constant(dense<0> : vector<1xi32>) : vector<1xi32>
    %6 = llvm.mlir.constant(dense<false> : vector<[4]xi1>) : vector<[4]xi1>
    %7 = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : vector<1xf32>
    %8 = llvm.mlir.constant(0.000000e+00 : f32) : f32
    %9 = llvm.mlir.constant(1 : index) : i64
    %10 = builtin.unrealized_conversion_cast %9 : i64 to index
    %11 = llvm.mlir.constant(0 : index) : i64
    %12 = builtin.unrealized_conversion_cast %11 : i64 to index
    %dim = tensor.dim %arg0, %12 : tensor<?x?xf32>
    %13 = builtin.unrealized_conversion_cast %dim : index to i64
    %dim_0 = tensor.dim %arg0, %10 : tensor<?x?xf32>
    %14 = builtin.unrealized_conversion_cast %dim_0 : index to i64
--> %15 = vector.create_mask %dim, %dim_0 : vector<[4]x1xi1>
--> %16 = vector.transfer_read %arg0[%12, %12], %8, %15 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[4]x1xf32>
    %17 = llvm.trunc %14 : i64 to i32
    %18 = llvm.insertelement %17, %4[%3 : i32] : vector<1xi32>
    %19 = llvm.shufflevector %18, %4 [0] : vector<1xi32>
    %20 = llvm.icmp "sgt" %19, %5 : vector<1xi32>
--> %21 = vector.transfer_read %arg1[%12], %8, %20 {in_bounds = [true]} : tensor<?xf32>, vector<1xf32>
    %22 = llvm.intr.experimental.stepvector : vector<[4]xi32>
    %23 = llvm.trunc %13 : i64 to i32
    %24 = llvm.insertelement %23, %2[%3 : i32] : vector<[4]xi32>
    %25 = llvm.shufflevector %24, %2 [0, 0, 0, 0] : vector<[4]xi32>
    %26 = llvm.icmp "slt" %22, %25 : vector<[4]xi32>
    %27 = llvm.icmp "sgt" %14, %11 : i64
    %28 = llvm.select %27, %26, %6 : i1, vector<[4]xi1>
--> %29 = vector.shape_cast %16 : vector<[4]x1xf32> to vector<1x[4]xf32>
    %30 = builtin.unrealized_conversion_cast %29 : vector<1x[4]xf32> to !llvm.array<1 x vector<[4]xf32>>
    %31 = llvm.extractvalue %30[0] : !llvm.array<1 x vector<[4]xf32>>
    %32 = llvm.extractelement %21[%1 : i64] : vector<1xf32>
    %33 = "llvm.intr.vscale"() : () -> i64
    %34 = llvm.trunc %33 : i64 to i32
    %35 = llvm.mul %34, %0 : i32
    %36 = "llvm.intr.vp.reduce.fadd"(%32, %31, %28, %35) : (f32, vector<[4]xf32>, vector<[4]xi1>, i32) -> f32
    %37 = llvm.insertelement %36, %7[%11 : i64] : vector<1xf32>
--> %38 = vector.transfer_write %37, %arg1[%12], %20 {in_bounds = [true]} : vector<1xf32>, tensor<?xf32>
    return %38 : tensor<?xf32>
  }
  module attributes {transform.with_named_sequence} {
  }
}
```
Note that some vector ops (marked with `-->` above) are not converted, and the results of `builtin.unrealized_conversion_cast` are still being used. `mlir-translate --mlir-to-llvmir` will fail because of these ops.

https://github.com/llvm/llvm-project/pull/98639


More information about the Mlir-commits mailing list