[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Oct 25 08:19:31 PDT 2023
================
@@ -427,12 +427,88 @@ struct TransposeOpToArmSMELowering
}
};
+/// Conversion pattern for vector.outerproduct.
+///
+/// If the vector.outerproduct is masked (and the mask comes from a
+/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
+/// operands.
+///
+/// Example:
+///
+/// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
+/// %result = vector.mask %mask {
+/// vector.outerproduct %vecA, %vecB
+/// : vector<[4]xf32>, vector<[4]xf32>
+/// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+///
+/// is converted to:
+///
+/// %maskA = vector.create_mask %dimA : vector<[4]xi1>
+/// %maskB = vector.create_mask %dimB : vector<[4]xi1>
+/// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
+/// : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+///
+struct VectorOuterProductToArmSMELowering
+ : public OpRewritePattern<vector::OuterProductOp> {
+
+ using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
+ PatternRewriter &rewriter) const override {
+ // AXPY operation not suited for SME.
+ if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+ return outerProductOp.emitError("AXPY operations not supported");
----------------
MacDue wrote:
SME does support it if you want it... You'd just have to mask out all but the first element of the LHS. We're choosing not to support it, though.
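The reviewer's remark can be checked with a small scalar model. This is a hypothetical sketch, not MLIR or ArmSME code. `maskedOuterProduct` and `axpyViaOuterProduct` are illustrative names. Vectors have a fixed length rather than a scalable one, and masked-off tile elements are assumed to keep their previous values. The sketch reads the suggestion as follows: splat the AXPY scalar into the LHS, pass the vector as the RHS, and keep only the first LHS lane active. Row 0 of the tile then holds `scalar * vec + acc`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical fixed-size model of a masked outer product (not an MLIR API):
// tile[i][j] += lhs[i] * rhs[j] only where maskLhs[i] && maskRhs[j].
// Masked-off tile elements keep their previous values (an assumption).
using Tile = std::vector<std::vector<float>>;

void maskedOuterProduct(const std::vector<float> &lhs,
                        const std::vector<float> &rhs,
                        const std::vector<bool> &maskLhs,
                        const std::vector<bool> &maskRhs, Tile &tile) {
  for (std::size_t i = 0; i < lhs.size(); ++i)
    for (std::size_t j = 0; j < rhs.size(); ++j)
      if (maskLhs[i] && maskRhs[j])
        tile[i][j] += lhs[i] * rhs[j];
}

// AXPY (result[j] = scalar * vec[j] + acc[j]) expressed as a masked outer
// product: splat the scalar into the LHS, keep only LHS lane 0 active, and
// read the result back from row 0 of the tile.
std::vector<float> axpyViaOuterProduct(float scalar,
                                       const std::vector<float> &vec,
                                       const std::vector<float> &acc) {
  assert(vec.size() == acc.size() && !vec.empty());
  const std::size_t n = vec.size();
  Tile tile(n, std::vector<float>(n, 0.0f));
  tile[0] = acc;                       // Accumulator lives in row 0.
  std::vector<float> lhs(n, scalar);   // Splat of the scalar operand.
  std::vector<bool> maskLhs(n, false); // Mask out all but the first LHS lane.
  maskLhs[0] = true;
  std::vector<bool> maskRhs(n, true);  // All RHS lanes stay active.
  maskedOuterProduct(lhs, vec, maskLhs, maskRhs, tile);
  return tile[0];
}
```

Only one row of the tile does useful work in this scheme. The other rows are masked off entirely.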
https://github.com/llvm/llvm-project/pull/69604
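The decomposition described in the doc comment works because of how a 2-D `vector.create_mask %dimA, %dimB` behaves. It enables element (i, j) exactly when `i < dimA` and `j < dimB`. That is the logical AND of the 1-D masks `create_mask %dimA` (indexed by i) and `create_mask %dimB` (indexed by j). The hypothetical C++ sketch below checks this identity on fixed-size vectors. `createMask1D`, `createMask2D`, and `decompositionHolds` are illustrative names, not MLIR APIs. The sketch also assumes non-negative mask bounds and models the scalable `[4]` dimensions as plain lengths.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical model of a 1-D create_mask over n lanes: lane i is active
// iff i < dim (bounds are non-negative here by construction).
std::vector<bool> createMask1D(std::size_t n, std::size_t dim) {
  std::vector<bool> mask(n);
  for (std::size_t i = 0; i < n; ++i)
    mask[i] = i < dim;
  return mask;
}

// Hypothetical model of a 2-D create_mask: element (i, j) is active
// iff i < dimA && j < dimB.
std::vector<std::vector<bool>> createMask2D(std::size_t rows, std::size_t cols,
                                            std::size_t dimA,
                                            std::size_t dimB) {
  std::vector<std::vector<bool>> mask(rows, std::vector<bool>(cols));
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      mask[i][j] = i < dimA && j < dimB;
  return mask;
}

// Returns true if the 2-D mask equals the "outer AND" of the two 1-D masks,
// i.e. mask2D[i][j] == maskA[i] && maskB[j] for every (i, j).
bool decompositionHolds(std::size_t rows, std::size_t cols, std::size_t dimA,
                        std::size_t dimB) {
  const auto mask2D = createMask2D(rows, cols, dimA, dimB);
  const auto maskA = createMask1D(rows, dimA);
  const auto maskB = createMask1D(cols, dimB);
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      if (mask2D[i][j] != (maskA[i] && maskB[j]))
        return false;
  return true;
}
```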