[Mlir-commits] [mlir] e7432ba - [mlir][ArmSME] Fail instead of error in vector.outerproduct lowering (#75447)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 14 23:30:37 PST 2023
Author: Cullen Rhodes
Date: 2023-12-15T07:30:32Z
New Revision: e7432babaf4fd9235be691591757321ce20e02da
URL: https://github.com/llvm/llvm-project/commit/e7432babaf4fd9235be691591757321ce20e02da
DIFF: https://github.com/llvm/llvm-project/commit/e7432babaf4fd9235be691591757321ce20e02da.diff
LOG: [mlir][ArmSME] Fail instead of error in vector.outerproduct lowering (#75447)
The 'vector.outerproduct' -> 'arm_sme.outerproduct' conversion currently
emits errors on unsupported cases when it should instead return a match
failure (via 'notifyMatchFailure').
Added:
Modified:
mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
mlir/test/Conversion/VectorToArmSME/unsupported.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 312e89c8f100dd..87d1bf9bed5a31 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -510,16 +510,18 @@ struct VectorOuterProductToArmSMELowering
// We don't yet support lowering AXPY operations to SME. These could be
// lowered by masking out all but the first element of the LHS.
if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
- return outerProductOp.emitError("AXPY operations not supported");
+ return rewriter.notifyMatchFailure(outerProductOp,
+ "AXPY operations not supported");
if (!arm_sme::isValidSMETileVectorType(
outerProductOp.getResultVectorType()))
- return outerProductOp.emitError(
- "outer product does not fit into SME tile");
+ return rewriter.notifyMatchFailure(
+ outerProductOp, "outer product does not fit into SME tile");
auto kind = outerProductOp.getKind();
if (kind != vector::CombiningKind::ADD)
- return outerProductOp.emitError(
+ return rewriter.notifyMatchFailure(
+ outerProductOp,
"unsupported kind (lowering to SME only supports ADD at the moment)");
Value lhsMask = {};
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 3ef283727edd49..35089ebebac7e1 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -151,25 +151,31 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
// -----
+// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK: vector.outerproduct
func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
- // expected-error@+1 {{AXPY operations not supported}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
return %0 : vector<[2]xf64>
}
// -----
+// CHECK-LABEL: @vector_outerproduct_unsupported_kind
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK: vector.outerproduct
func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
%acc = arm_sme.get_tile : vector<[2]x[2]xf64>
- // expected-error@+1 {{unsupported kind}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}
// -----
+// CHECK-LABEL: @vector_outerproduct_unknown_mask
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK: vector.outerproduct
func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
- // CHECK: vector.outerproduct
%acc = arm_sme.get_tile : vector<[4]x[4]xf32>
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
More information about the Mlir-commits
mailing list