[Mlir-commits] [mlir] [mlir][ArmSME] Fail instead of error in vector.outerproduct lowering (PR #75447)

Cullen Rhodes llvmlistbot at llvm.org
Thu Dec 14 01:21:25 PST 2023


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/75447

The 'vector.outerproduct' -> 'arm_sme.outerproduct' conversion currently
emits an error on unsupported cases, when it should instead return a match
failure and leave the 'vector.outerproduct' op unconverted.

>From ad858654b0d79569757679ed536f60154ceaa4d5 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 14 Dec 2023 08:54:41 +0000
Subject: [PATCH 1/2] [mlir][ArmSME][NFC] Move conversion tests

* Move -vector-to-arm-sme tests to mlir/test/Conversion/VectorToArmSME
* Move -arm-sme-to-llvm tests to mlir/test/Conversion/ArmSMEToLLVM
* Move tests for unsupported cases into separate unsupported.mlir files.
---
 .../ArmSMEToLLVM}/arm-sme-to-llvm.mlir        |   0
 .../Conversion/ArmSMEToLLVM/unsupported.mlir  |  14 ++
 .../VectorToArmSME/unsupported.mlir           | 176 ++++++++++++++++++
 .../VectorToArmSME/vector-to-arm-sme.mlir}    | 139 --------------
 .../Dialect/ArmSME/vector-ops-to-llvm.mlir    |  36 ----
 5 files changed, 190 insertions(+), 175 deletions(-)
 rename mlir/test/{Dialect/ArmSME => Conversion/ArmSMEToLLVM}/arm-sme-to-llvm.mlir (100%)
 create mode 100644 mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
 create mode 100644 mlir/test/Conversion/VectorToArmSME/unsupported.mlir
 rename mlir/test/{Dialect/ArmSME/vector-ops-to-sme.mlir => Conversion/VectorToArmSME/vector-to-arm-sme.mlir} (85%)

diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
similarity index 100%
rename from mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
rename to mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
new file mode 100644
index 00000000000000..59665c471921d5
--- /dev/null
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -split-input-file -allow-unregistered-dialect -verify-diagnostics
+
+//===----------------------------------------------------------------------===//
+// arm_sme.outerproduct
+//===----------------------------------------------------------------------===//
+
+func.func @arm_sme_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>) {
+  %acc = arm_sme.get_tile : vector<[16]x[16]xi8>
+  // expected-error at +2 {{failed to legalize operation 'arm_sme.outerproduct'}}
+  // expected-error at +1 {{unsupported type}}
+  %0 = arm_sme.outerproduct %lhs, %rhs  acc(%acc) : vector<[16]xi8>, vector<[16]xi8>
+  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+}
+
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
new file mode 100644
index 00000000000000..3ef283727edd49
--- /dev/null
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -0,0 +1,176 @@
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// vector.transfer_read
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @transfer_read_2d__bad_type
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__non_memref_type
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__non_transpose
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__non_transpose(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, 0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__out_of_bounds
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__out_of_bounds(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// vector.transfer_write
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
+// lowering only occurs for vector types of correct rank, shape, element size
+// and number of scalable dims.
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_type
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_type(%arg0 : memref<?x?xi4>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[16]x[16]xi4>
+  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi4>, memref<?x?xi4>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[8]x[8]xi8>
+  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi8>, memref<?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[16]x[16]x[16]xi8>
+  vector.transfer_write %cst, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16]x[16]x[16]xi8>, memref<?x?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__non_memref_type
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> tensor<?x?xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
+  %0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
+  return %0 : tensor<?x?xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d__fixed
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi8>, memref<?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d__out_of_bounds
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [false, false]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// vector.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
+  // expected-error at +1 {{AXPY operations not supported}}
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
+  return %0 : vector<[2]xf64>
+}
+
+// -----
+
+func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
+  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
+  // expected-error at +1 {{unsupported kind}}
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
+  // CHECK: vector.outerproduct
+  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
similarity index 85%
rename from mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
rename to mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 6ea949d9c16509..6783263c184961 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -150,71 +150,6 @@ func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mas
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__bad_type
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
-  %c0 = arith.constant 0 : index
-  %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
-  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transfer_read_2d__non_memref_type
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
-  %c0 = arith.constant 0 : index
-  %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
-  %c0 = arith.constant 0 : index
-  %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transfer_read_2d__non_transpose
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__non_transpose(%src : memref<?x?xf64>) {
-  %c0 = arith.constant 0 : index
-  %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, 0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transfer_read_2d__out_of_bounds
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__out_of_bounds(%src : memref<?x?xf64>) {
-  %c0 = arith.constant 0 : index
-  %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[2]x[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
-  return
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // vector.transfer_write
 //===----------------------------------------------------------------------===//
@@ -366,80 +301,6 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
   return
 }
 
-// -----
-
-// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
-// lowering only occurs for vector types of correct rank, shape, element size
-// and number of scalable dims.
-
-// CHECK-LABEL: @transfer_write_2d_zero__bad_type
-// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.intr.zero
-func.func @transfer_write_2d_zero__bad_type(%arg0 : memref<?x?xi4>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0> : vector<[16]x[16]xi4>
-  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi4>, memref<?x?xi4>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
-// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.tile_store
-func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0> : vector<[8]x[8]xi8>
-  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi8>, memref<?x?xi8>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
-// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.tile_store
-func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0> : vector<[16]x[16]x[16]xi8>
-  vector.transfer_write %cst, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16]x[16]x[16]xi8>, memref<?x?x?xi8>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transfer_write_2d_zero__non_memref_type
-// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.tile_store
-func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> tensor<?x?xi8> {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
-  %0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
-  return %0 : tensor<?x?xi8>
-}
-
-// -----
-
-// CHECK-LABEL: @transfer_write_2d__fixed
-// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.tile_store
-func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?x?xi8>) {
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi8>, memref<?x?xi8>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @transfer_write_2d__out_of_bounds
-// CHECK: vector.transfer_write
-// CHECK-NOT: arm_sme.tile_store
-func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [false, false]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  return
-}
-
 //===----------------------------------------------------------------------===//
 // vector.broadcast
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index d16d250b70eb3f..ce5bfd25cbdbcc 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -469,42 +469,6 @@ func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<
   "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
 }
 
-// -----
-
-func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
-  // expected-error at +1 {{AXPY operations not supported}}
-  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
-  return %0 : vector<[2]xf64>
-}
-
-// -----
-
-func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>) {
-  %acc = arm_sme.get_tile : vector<[16]x[16]xi8>
-  // expected-error at +2 {{failed to legalize operation 'arm_sme.outerproduct'}}
-  // expected-error at +1 {{unsupported type}}
-  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
-  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
-}
-
-// -----
-
-func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
-  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
-  // expected-error at +1 {{unsupported kind}}
-  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
-}
-
-// -----
-
-func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
-  // CHECK: vector.outerproduct
-  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
-  %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
-  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
-}
-
 //===----------------------------------------------------------------------===//
 // vector.insert
 //===----------------------------------------------------------------------===//

>From 91483f32babe811fd475ecd69efa4109e248cc1d Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 14 Dec 2023 08:15:18 +0000
Subject: [PATCH 2/2] [mlir][ArmSME] Fail instead of error in
 vector.outerproduct lowering

The 'vector.outerproduct' -> 'arm_sme.outerproduct' conversion currently
emits an error on unsupported cases, when it should instead return a match
failure and leave the 'vector.outerproduct' op unconverted.
---
 .../lib/Conversion/VectorToArmSME/VectorToArmSME.cpp | 10 ++++++----
 mlir/test/Conversion/VectorToArmSME/unsupported.mlir | 12 +++++++++---
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 312e89c8f100dd..87d1bf9bed5a31 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -510,16 +510,18 @@ struct VectorOuterProductToArmSMELowering
     // We don't yet support lowering AXPY operations to SME. These could be
     // lowered by masking out all but the first element of the LHS.
     if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
-      return outerProductOp.emitError("AXPY operations not supported");
+      return rewriter.notifyMatchFailure(outerProductOp,
+                                         "AXPY operations not supported");
 
     if (!arm_sme::isValidSMETileVectorType(
             outerProductOp.getResultVectorType()))
-      return outerProductOp.emitError(
-          "outer product does not fit into SME tile");
+      return rewriter.notifyMatchFailure(
+          outerProductOp, "outer product does not fit into SME tile");
 
     auto kind = outerProductOp.getKind();
     if (kind != vector::CombiningKind::ADD)
-      return outerProductOp.emitError(
+      return rewriter.notifyMatchFailure(
+          outerProductOp,
           "unsupported kind (lowering to SME only supports ADD at the moment)");
 
     Value lhsMask = {};
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 3ef283727edd49..35089ebebac7e1 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -151,25 +151,31 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK:     vector.outerproduct
 func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
-  // expected-error at +1 {{AXPY operations not supported}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
   return %0 : vector<[2]xf64>
 }
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_unsupported_kind
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK:     vector.outerproduct
 func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
   %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
-  // expected-error at +1 {{unsupported kind}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
 }
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_unknown_mask
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK:     vector.outerproduct
 func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
-  // CHECK: vector.outerproduct
   %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()



More information about the Mlir-commits mailing list