[Mlir-commits] [mlir] [mlir][ArmSME] Fail instead of error in vector.outerproduct lowering (PR #75447)
Cullen Rhodes
llvmlistbot at llvm.org
Thu Dec 14 03:01:20 PST 2023
https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/75447
>From 17d111ffb264d4ab8b584261e6fef84c8fc3eb6e Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 14 Dec 2023 08:15:18 +0000
Subject: [PATCH] [mlir][ArmSME] Fail instead of error in vector.outerproduct
lowering
The 'vector.outerproduct' -> 'arm_sme.outerproduct' conversion currently
emits an error on unsupported cases when it should instead return failure
(via notifyMatchFailure).
---
.../lib/Conversion/VectorToArmSME/VectorToArmSME.cpp | 10 ++++++----
mlir/test/Conversion/VectorToArmSME/unsupported.mlir | 12 +++++++++---
2 files changed, 15 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 312e89c8f100dd..87d1bf9bed5a31 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -510,16 +510,18 @@ struct VectorOuterProductToArmSMELowering
// We don't yet support lowering AXPY operations to SME. These could be
// lowered by masking out all but the first element of the LHS.
if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
- return outerProductOp.emitError("AXPY operations not supported");
+ return rewriter.notifyMatchFailure(outerProductOp,
+ "AXPY operations not supported");
if (!arm_sme::isValidSMETileVectorType(
outerProductOp.getResultVectorType()))
- return outerProductOp.emitError(
- "outer product does not fit into SME tile");
+ return rewriter.notifyMatchFailure(
+ outerProductOp, "outer product does not fit into SME tile");
auto kind = outerProductOp.getKind();
if (kind != vector::CombiningKind::ADD)
- return outerProductOp.emitError(
+ return rewriter.notifyMatchFailure(
+ outerProductOp,
"unsupported kind (lowering to SME only supports ADD at the moment)");
Value lhsMask = {};
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 3ef283727edd49..35089ebebac7e1 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -151,25 +151,31 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
// -----
+// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK: vector.outerproduct
func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
- // expected-error@+1 {{AXPY operations not supported}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
return %0 : vector<[2]xf64>
}
// -----
+// CHECK-LABEL: @vector_outerproduct_unsupported_kind
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK: vector.outerproduct
func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
%acc = arm_sme.get_tile : vector<[2]x[2]xf64>
- // expected-error@+1 {{unsupported kind}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}
// -----
+// CHECK-LABEL: @vector_outerproduct_unknown_mask
+// CHECK-NOT: arm_sme.outerproduct
+// CHECK: vector.outerproduct
func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
- // CHECK: vector.outerproduct
%acc = arm_sme.get_tile : vector<[4]x[4]xf32>
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
More information about the Mlir-commits
mailing list