[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Oct 26 01:55:56 PDT 2023
================
@@ -427,12 +427,88 @@ struct TransposeOpToArmSMELowering
}
};
+/// Conversion pattern for vector.outerproduct.
+///
+/// If the vector.outerproduct is masked (and the mask from a
+/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
+/// operands.
+///
+/// Example:
+///
+/// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
+/// %result = vector.mask %mask {
+/// vector.outerproduct %vecA, %vecB
+/// : vector<[4]xf32>, vector<[4]xf32>
+/// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+///
+/// is converted to:
+///
+/// %maskA = vector.create_mask %dimA : vector<[4]xi1>
+/// %maskB = vector.create_mask %dimB : vector<[4]xi1>
+/// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
+/// : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+///
+struct VectorOuterProductToArmSMELowering
+ : public OpRewritePattern<vector::OuterProductOp> {
+
+ using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
+ PatternRewriter &rewriter) const override {
+ // AXPY operation not suited for SME.
+ if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+ return outerProductOp.emitError("AXPY operations not supported");
+
+ auto kind = outerProductOp.getKind();
+ if (kind != vector::CombiningKind::ADD)
+ return outerProductOp.emitError("unsupported kind");
+
+ Value lhsMask = {};
+ Value rhsMask = {};
+ Operation *rootOp = outerProductOp;
+ if (outerProductOp.isMasked()) {
+ auto maskingOp = outerProductOp.getMaskingOp();
+ rewriter.setInsertionPoint(maskingOp);
+ rootOp = maskingOp;
+
+ // Attempt to extract masks from vector.create_mask.
+ // TODO: Add support for other mask sources.
+ auto mask = maskingOp.getMask();
+ auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return failure();
+
+ auto maskType = createMaskOp.getVectorType();
+ if (maskType.getRank() != 2)
----------------
banach-space wrote:
This should never happen, right? As in, that would be an invalid `vector.outerproduct`? Just checking my understanding.
https://github.com/llvm/llvm-project/pull/69604
More information about the Mlir-commits
mailing list