[Mlir-commits] [mlir] [mlir][ArmSME] Fail instead of error in vector.outerproduct lowering (PR #75447)

Cullen Rhodes llvmlistbot at llvm.org
Thu Dec 14 03:01:20 PST 2023


https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/75447

From 17d111ffb264d4ab8b584261e6fef84c8fc3eb6e Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 14 Dec 2023 08:15:18 +0000
Subject: [PATCH] [mlir][ArmSME] Fail instead of error in vector.outerproduct
 lowering

The 'vector.outerproduct' -> 'arm_sme.outerproduct' conversion currently emits
an error on unsupported cases when it should instead return a match failure
(via notifyMatchFailure), leaving the unsupported op unconverted.
---
 .../lib/Conversion/VectorToArmSME/VectorToArmSME.cpp | 10 ++++++----
 mlir/test/Conversion/VectorToArmSME/unsupported.mlir | 12 +++++++++---
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 312e89c8f100dd..87d1bf9bed5a31 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -510,16 +510,18 @@ struct VectorOuterProductToArmSMELowering
     // We don't yet support lowering AXPY operations to SME. These could be
     // lowered by masking out all but the first element of the LHS.
     if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
-      return outerProductOp.emitError("AXPY operations not supported");
+      return rewriter.notifyMatchFailure(outerProductOp,
+                                         "AXPY operations not supported");
 
     if (!arm_sme::isValidSMETileVectorType(
             outerProductOp.getResultVectorType()))
-      return outerProductOp.emitError(
-          "outer product does not fit into SME tile");
+      return rewriter.notifyMatchFailure(
+          outerProductOp, "outer product does not fit into SME tile");
 
     auto kind = outerProductOp.getKind();
     if (kind != vector::CombiningKind::ADD)
-      return outerProductOp.emitError(
+      return rewriter.notifyMatchFailure(
+          outerProductOp,
           "unsupported kind (lowering to SME only supports ADD at the moment)");
 
     Value lhsMask = {};
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 3ef283727edd49..35089ebebac7e1 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -151,25 +151,31 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK:     vector.outerproduct
 func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
-  // expected-error@+1 {{AXPY operations not supported}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
   return %0 : vector<[2]xf64>
 }
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_unsupported_kind
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK:     vector.outerproduct
 func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
   %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
-  // expected-error@+1 {{unsupported kind}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
 }
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_unknown_mask
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK:     vector.outerproduct
 func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
-  // CHECK: vector.outerproduct
   %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()



More information about the Mlir-commits mailing list