[Mlir-commits] [mlir] e7432ba - [mlir][ArmSME] Fail instead of error in vector.outerproduct lowering (#75447)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 14 23:30:37 PST 2023


Author: Cullen Rhodes
Date: 2023-12-15T07:30:32Z
New Revision: e7432babaf4fd9235be691591757321ce20e02da

URL: https://github.com/llvm/llvm-project/commit/e7432babaf4fd9235be691591757321ce20e02da
DIFF: https://github.com/llvm/llvm-project/commit/e7432babaf4fd9235be691591757321ce20e02da.diff

LOG: [mlir][ArmSME] Fail instead of error in vector.outerproduct lowering (#75447)

The 'vector.outerproduct' -> 'arm_sme.outerproduct' conversion currently
emits an error on unsupported cases when it should instead report a match
failure.

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/test/Conversion/VectorToArmSME/unsupported.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 312e89c8f100dd..87d1bf9bed5a31 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -510,16 +510,18 @@ struct VectorOuterProductToArmSMELowering
     // We don't yet support lowering AXPY operations to SME. These could be
     // lowered by masking out all but the first element of the LHS.
     if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
-      return outerProductOp.emitError("AXPY operations not supported");
+      return rewriter.notifyMatchFailure(outerProductOp,
+                                         "AXPY operations not supported");
 
     if (!arm_sme::isValidSMETileVectorType(
             outerProductOp.getResultVectorType()))
-      return outerProductOp.emitError(
-          "outer product does not fit into SME tile");
+      return rewriter.notifyMatchFailure(
+          outerProductOp, "outer product does not fit into SME tile");
 
     auto kind = outerProductOp.getKind();
     if (kind != vector::CombiningKind::ADD)
-      return outerProductOp.emitError(
+      return rewriter.notifyMatchFailure(
+          outerProductOp,
           "unsupported kind (lowering to SME only supports ADD at the moment)");
 
     Value lhsMask = {};

diff  --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 3ef283727edd49..35089ebebac7e1 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -151,25 +151,31 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK:     vector.outerproduct
 func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
-  // expected-error@+1 {{AXPY operations not supported}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
   return %0 : vector<[2]xf64>
 }
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_unsupported_kind
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK:     vector.outerproduct
 func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
   %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
-  // expected-error@+1 {{unsupported kind}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
 }
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_unknown_mask
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK:     vector.outerproduct
 func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
-  // CHECK: vector.outerproduct
   %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()


        


More information about the Mlir-commits mailing list