[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Oct 26 01:55:58 PDT 2023
================
@@ -427,12 +427,88 @@ struct TransposeOpToArmSMELowering
}
};
+/// Conversion pattern for vector.outerproduct.
+///
+/// If the vector.outerproduct is masked (and the mask comes from a
+/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
+/// operands.
+///
+/// Example:
+///
+/// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
+/// %result = vector.mask %mask {
+/// vector.outerproduct %vecA, %vecB
+/// : vector<[4]xf32>, vector<[4]xf32>
+/// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+///
+/// is converted to:
+///
+/// %maskA = vector.create_mask %dimA : vector<[4]xi1>
+/// %maskB = vector.create_mask %dimB : vector<[4]xi1>
+/// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
+/// : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+///
+struct VectorOuterProductToArmSMELowering
+ : public OpRewritePattern<vector::OuterProductOp> {
+
+ using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
+ PatternRewriter &rewriter) const override {
+ // AXPY operation not suited for SME.
+ if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+ return outerProductOp.emitError("AXPY operations not supported");
+
+ auto kind = outerProductOp.getKind();
+ if (kind != vector::CombiningKind::ADD)
+ return outerProductOp.emitError("unsupported kind");
+
+ Value lhsMask = {};
+ Value rhsMask = {};
+ Operation *rootOp = outerProductOp;
+ if (outerProductOp.isMasked()) {
+ auto maskingOp = outerProductOp.getMaskingOp();
+ rewriter.setInsertionPoint(maskingOp);
+ rootOp = maskingOp;
+
+ // Attempt to extract masks from vector.create_mask.
+ // TODO: Add support for other mask sources.
+ auto mask = maskingOp.getMask();
+ auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return failure();
+
+ auto maskType = createMaskOp.getVectorType();
+ if (maskType.getRank() != 2)
+ return failure();
+
+ auto loc = outerProductOp.getLoc();
+
+ Value lhsMaskDim = createMaskOp.getOperand(0);
+ Value rhsMaskDim = createMaskOp.getOperand(1);
+
+ VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
+ lhsMask = rewriter.create<vector::CreateMaskOp>(loc, operandMaskType,
+ lhsMaskDim);
+ rhsMask = rewriter.create<vector::CreateMaskOp>(loc, operandMaskType,
+ rhsMaskDim);
+ }
+
+ rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
+ rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
+ outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
+
+ return success();
----------------
banach-space wrote:
Try to avoid large, deeply indented blocks: https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code
```suggestion
if (!outerProductOp.isMasked()) {
rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
return success();
}
auto maskingOp = outerProductOp.getMaskingOp();
rewriter.setInsertionPoint(maskingOp);
rootOp = maskingOp;
// Attempt to extract masks from vector.create_mask.
// TODO: Add support for other mask sources.
auto mask = maskingOp.getMask();
auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return failure();
auto maskType = createMaskOp.getVectorType();
if (maskType.getRank() != 2
return failure();
auto loc = outerProductOp.getLoc();
Value lhsMaskDim = createMaskOp.getOperand(0);
Value rhsMaskDim = createMaskOp.getOperand(1);
VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
lhsMask = rewriter.create<vector::CreateMaskOp>(loc, operandMaskType,
lhsMaskDim);
rhsMask = rewriter.create<vector::CreateMaskOp>(loc, operandMaskType,
rhsMaskDim);
}
```
https://github.com/llvm/llvm-project/pull/69604
More information about the Mlir-commits mailing list