[Mlir-commits] [mlir] [mlir][SVE] Add an e2e test for vector.contract (PR #69845)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Oct 26 05:00:52 PDT 2023


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/69845

>From 74d86703223e3c5fba3f424189643889a11ae9a4 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 20 Oct 2023 16:13:49 +0000
Subject: [PATCH] [mlir][SVE] Add an e2e test for vector.contract

Adds an end-to-end test for `vector.contract` that targets SVE (i.e.
scalable vectors). Note that this requires relaxing a restriction on
`vector.outerproduct` (to which `vector.contract` is lowered): previously,
the Op verifier would reject the following as invalid:

```
vector.outerproduct %27, %28, %26 {kind = #vector.kind<add>} : vector<3xf32>, vector<[2]xf32>
```

This is indeed valid, as the end-to-end test demonstrates (at least when
compiling for SVE).

Depends on #68794
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |   9 +-
 ...r-contract-to-outerproduct-transforms.mlir |  59 ++++++--
 .../Vector/vector-scalable-outerproduct.mlir  |   5 +-
 .../Vector/CPU/ArmSVE/test-contraction.mlir   | 137 ++++++++++++++++++
 4 files changed, 191 insertions(+), 19 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8de132daf3c6a5d..d77476c10908395 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3067,9 +3067,12 @@ LogicalResult OuterProductOp::verify() {
       return emitOpError("expected #1 operand dim to match result dim #1");
     if (vRHS.getDimSize(0) != vRES.getDimSize(1))
       return emitOpError("expected #2 operand dim to match result dim #2");
-    if (vRHS.isScalable() != vLHS.isScalable())
-      return emitOpError("expected either all or none of vector operands #1 "
-                         "and #2 to be scalable");
+    if (vLHS.isScalable() && !vRHS.isScalable()) {
+      // This restriction reflects what's currently supported in terms of
+      // scalable vectors. However, we could relax this if there's a use case.
+      return emitOpError(
+          "expected either both or only #2 operand dim to be scalable");
+    }
   } else {
     // An AXPY operation.
     if (vRES.getRank() != 1)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 3530393c013a1ae..44fb23088cea933 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -79,21 +79,21 @@ func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf3
 }
 
 // CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME:                                      %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME:                                      %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME:                                      %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME:                                      %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK:         %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK:         %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME:    %{{.*}}: vector<5x7xf32>,
+// CHECK-SAME:    %{{.*}}: vector<3x7xf32>,
+// CHECK-SAME:    %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// CHECK:         %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// CHECK:         %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
 
 func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
                                     %arg1: vector<5x7xf32>,
@@ -104,6 +104,35 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
   return %0 : vector<3x7xf32>
 }
 
+// CHECK-LABEL: func.func @masked_extract_contract4_scalable_J_dim(
+// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME:    %{{.*}}: vector<5x[7]xf32>,
+// CHECK-SAME:    %{{.*}}: vector<3x[7]xf32>,
+// CHECK-SAME:    %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+// CHECK:         %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
+// CHECK:         %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK:         %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK:         %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK:         %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+
+// Note that only the J dimension is scalable in this example. In theory, all
+// dimensions could be scalable, but there is no target yet for which this
+// would make sense.
+func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
+                                    %arg1: vector<5x[7]xf32>,
+                                    %arg2: vector<3x[7]xf32>,
+                                    %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+  : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
+  return %0 : vector<3x[7]xf32>
+}
+
 // CHECK-LABEL: func @matmul
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
diff --git a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
index 7d9923e036660c9..3b4e24da92aaacc 100644
--- a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
+++ b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
@@ -21,9 +21,12 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
   %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
   %1 = vector.load %src[%idx] : memref<?xf32>, vector<4xf32>
 
-  // expected-error @+1 {{expected either all or none of vector operands #1 and #2 to be scalable}}
+  // expected-error @+1 {{expected either both or only #2 operand dim to be scalable}}
   %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>
+
+  return
 }
+
 // -----
 
 func.func @invalid_outerproduct1(%src : memref<?xf32>) {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
new file mode 100644
index 000000000000000..12187dfd7787155
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
@@ -0,0 +1,137 @@
+// DEFINE: %{compile} = mlir-opt %s  -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule\
+// DEFINE:    -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage\
+// DEFINE:    -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm
+// DEFINE: %{entry} =
+// DEFINE: %{run} = %mcr_aarch64_cmd -e=%{entry} -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext
+
+// REDEFINE: %{entry} = entry_i32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=I32
+
+// REDEFINE: %{entry} = entry_f32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=F32
+
+// Plain matmul: C(m, n) += A(m, k) * B(k, n), reducing over k.
+#matmat_accesses = [
+  affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait = {
+  indexing_maps = #matmat_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+func.func @entry_i32() {
+  // A is 3x5, B is 5x[2] and C is 3x[2]: only the J (column) dim is scalable.
+  %vscale = vector.vscale
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c5 = arith.constant 5 : index
+  %n_cols = arith.muli %vscale, %c2 : index
+
+  %cst = arith.constant 0 : i32
+  %i32_123 = arith.constant 123 : i32
+  %i32_314 = arith.constant 314 : i32
+
+  // Allocate and initialize matrix A (3x5, every element is 123)
+  %A_alloc = memref.alloca() : memref<3x5xi32>
+  linalg.fill ins(%i32_123 : i32) outs(%A_alloc : memref<3x5xi32>)
+  %mask_a = vector.create_mask %c3, %c5 : vector<3x5xi1>
+  %vector_a = vector.transfer_read %A_alloc[%c0, %c0], %cst, %mask_a {in_bounds = [true, true]} : memref<3x5xi32>, vector<3x5xi32>
+
+  // Allocate and initialize matrix B (5 x 2*vscale, every element is 123)
+  %B_alloc = memref.alloca(%n_cols) : memref<5x?xi32>
+  linalg.fill ins(%i32_123 : i32) outs(%B_alloc : memref<5x?xi32>)
+  %mask_b = vector.create_mask %c5, %n_cols : vector<5x[2]xi1>
+  %vector_b = vector.transfer_read %B_alloc[%c0, %c0], %cst, %mask_b {in_bounds = [true, true]} : memref<5x?xi32>, vector<5x[2]xi32>
+
+  // Allocate and initialize matrix C (3 x 2*vscale, every element is 314)
+  %C_alloc = memref.alloca(%n_cols) : memref<3x?xi32>
+  linalg.fill ins(%i32_314 : i32) outs(%C_alloc : memref<3x?xi32>)
+  %mask_c = vector.create_mask %c3, %n_cols : vector<3x[2]xi1>
+  %vector_c = vector.transfer_read %C_alloc[%c0, %c0], %cst, %mask_c {in_bounds = [true, true]} : memref<3x?xi32>, vector<3x[2]xi32>
+
+  // Matmul: C += A * B, so every element is 314 + 5 * 123 * 123 = 75959
+  %m = vector.create_mask %c3, %n_cols, %c5 : vector<3x[2]x5xi1>
+  %0 = vector.mask %m { vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
+    : vector<3x5xi32>, vector<5x[2]xi32> into vector<3x[2]xi32> } : vector<3x[2]x5xi1> -> vector<3x[2]xi32>
+
+  // Print the output. Only 2 lanes are checked so that any vscale >= 1 works.
+  %slice1 = vector.extract %0[0] : vector<[2]xi32> from vector<3x[2]xi32>
+  // I32: ( 75959, 75959
+  vector.print %slice1 : vector<[2]xi32>
+  %slice2 = vector.extract %0[1] : vector<[2]xi32> from vector<3x[2]xi32>
+  // I32-NEXT: ( 75959, 75959
+  vector.print %slice2 : vector<[2]xi32>
+  %slice3 = vector.extract %0[2] : vector<[2]xi32> from vector<3x[2]xi32>
+  // I32-NEXT: ( 75959, 75959
+  vector.print %slice3 : vector<[2]xi32>
+
+  // I32: SVE: END OF TEST OUTPUT
+  vector.print str "SVE: END OF TEST OUTPUT"
+
+  return
+}
+
+func.func @entry_f32() {
+  // A is 3x5, B is 5x[2] and C is 3x[2]: only the J (column) dim is scalable.
+  %vscale = vector.vscale
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c5 = arith.constant 5 : index
+  %n_cols = arith.muli %vscale, %c2 : index
+
+  %cst = arith.constant 0.0 : f32
+  %f32_123 = arith.constant 1.23 : f32
+  %f32_314 = arith.constant 3.14 : f32
+
+  // Allocate and initialize matrix A (3x5, every element is 1.23)
+  %A_alloc = memref.alloca() : memref<3x5xf32>
+  linalg.fill ins(%f32_123 : f32) outs(%A_alloc : memref<3x5xf32>)
+  %mask_a = vector.create_mask %c3, %c5 : vector<3x5xi1>
+  %vector_a = vector.transfer_read %A_alloc[%c0, %c0], %cst, %mask_a {in_bounds = [true, true]} : memref<3x5xf32>, vector<3x5xf32>
+
+  // Allocate and initialize matrix B (5 x 2*vscale, every element is 1.23)
+  %B_alloc = memref.alloca(%n_cols) : memref<5x?xf32>
+  linalg.fill ins(%f32_123 : f32) outs(%B_alloc : memref<5x?xf32>)
+  %mask_b = vector.create_mask %c5, %n_cols : vector<5x[2]xi1>
+  %vector_b = vector.transfer_read %B_alloc[%c0, %c0], %cst, %mask_b {in_bounds = [true, true]} : memref<5x?xf32>, vector<5x[2]xf32>
+
+  // Allocate and initialize matrix C (3 x 2*vscale, every element is 3.14)
+  %C_alloc = memref.alloca(%n_cols) : memref<3x?xf32>
+  linalg.fill ins(%f32_314 : f32) outs(%C_alloc : memref<3x?xf32>)
+  %mask_c = vector.create_mask %c3, %n_cols : vector<3x[2]xi1>
+  %vector_c = vector.transfer_read %C_alloc[%c0, %c0], %cst, %mask_c {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[2]xf32>
+
+  // Matmul: C += A * B, so every element is 3.14 + 5 * 1.23 * 1.23 = 10.7045
+  %m = vector.create_mask %c3, %n_cols, %c5 : vector<3x[2]x5xi1>
+  %0 = vector.mask %m { vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
+    : vector<3x5xf32>, vector<5x[2]xf32> into vector<3x[2]xf32> } : vector<3x[2]x5xi1> -> vector<3x[2]xf32>
+
+  // Print the output. Only 2 lanes are checked so that any vscale >= 1 works.
+  %slice1 = vector.extract %0[0] : vector<[2]xf32> from vector<3x[2]xf32>
+  // F32: ( 10.7045, 10.7045
+  vector.print %slice1 : vector<[2]xf32>
+  %slice2 = vector.extract %0[1] : vector<[2]xf32> from vector<3x[2]xf32>
+  // F32-NEXT: ( 10.7045, 10.7045
+  vector.print %slice2 : vector<[2]xf32>
+  %slice3 = vector.extract %0[2] : vector<[2]xf32> from vector<3x[2]xf32>
+  // F32-NEXT: ( 10.7045, 10.7045
+  vector.print %slice3 : vector<[2]xf32>
+
+  // F32: SVE: END OF TEST OUTPUT
+  vector.print str "SVE: END OF TEST OUTPUT"
+
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%root: !transform.any_op):
+  %funcs = transform.structured.match ops{["func.func"]} in %root
+    : (!transform.any_op) -> !transform.any_op
+  // Lower every vector.contract in the matched functions via outer products.
+  transform.apply_patterns to %funcs {
+    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+  } : !transform.any_op
+}



More information about the Mlir-commits mailing list