[Mlir-commits] [mlir] [mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA (PR #65621)
Diego Caballero
llvmlistbot at llvm.org
Sun Sep 10 23:14:56 PDT 2023
================
@@ -361,6 +361,111 @@ struct MoveVectorToTileSliceToArmSMELowering
}
};
+/// Lower `vector.outerproduct` to SME MOPA intrinsics.
+///
+/// Example:
+///
+/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
+/// : vector<[4]xf32>, vector<[4]xf32>
+///
+/// is converted to:
+///
+/// "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
+/// : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
+/// vector<[4]xf32>) -> ()
+///
+/// Currently only supports FMOPA and BFMOPA (non-widening).
+struct VectorOuterProductToArmSMELowering
+ : public ConvertOpToLLVMPattern<vector::OuterProductOp> {
+ using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::OuterProductOp outerProductOp,
+ vector::OuterProductOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto isSupportedType = [](VectorType vectorType) {
+ // TODO: the FP outer product instruction variants are predicated on
+ // different features:
+ //
+ // * FMOPA (non-widening)
+ // * half-precision - +sme2p1,+sme-f16f16
+ // * single-precision - +sme
+ // * double-precision - +sme-f64f64
+ // * BFMOPA
+ // * half-precision - +sme2p1,+b16b16
+ //
+ // It should be possible to control lowering based on target features.
+ if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
+ return false;
+
+ auto elementType = vectorType.getElementType();
+
+ if (!elementType.isF16() && !elementType.isBF16() &&
+ !elementType.isF32() && !elementType.isF64())
+ return false;
+
+ unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
+ vectorType.getElementTypeBitWidth();
+ if (vectorType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
+ return false;
+
+ return true;
+ };
+
+ auto resultVectorType = outerProductOp.getResultVectorType();
+ if (!isSupportedType(resultVectorType))
+ return outerProductOp.emitError("unsupported type");
+
+ vector::CombiningKind kind = outerProductOp.getKind();
+ if (kind != vector::CombiningKind::ADD)
+ // TODO: support subtract.
+ return outerProductOp.emitError("unsupported kind");
+
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
+ if (maskableOp.isMasked())
+ // TODO: support masking.
+ return outerProductOp.emitError("masking is currently unsupported");
+
+ if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+ // AXPY operation not suited for SME.
+ return failure();
+
+ auto loc = outerProductOp.getLoc();
+
+ Value acc = outerProductOp.getAcc();
+ if (!acc)
+      // Initialize accumulator with zero.
+ acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+
+ unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
+ auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
+ loc, rewriter.getIntegerType(elementWidth), acc);
+
+ // Create all active predicate mask.
+ auto one = rewriter.create<arith::ConstantOp>(
----------------
dcaballe wrote:
We should use a ConstantMaskOp here
https://github.com/llvm/llvm-project/pull/65621
More information about the Mlir-commits
mailing list