[Mlir-commits] [mlir] [mlir][ArmSME] Support 4-way widening outer products (PR #79288)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Feb 6 06:54:33 PST 2024
================
@@ -225,37 +271,173 @@ class OuterProductFusion2Way
return success();
}
+};
- // An outer product is compatible if all of the following are true:
- // - the result type matches `resultType`.
- // - the defining operations of the inputs are identical and of the type
- // `ExtOp`.
- // - the input types of the defining operations are identical and match
- // `inputType`.
- template <typename ExtOp>
- LogicalResult isCompatible(PatternRewriter &rewriter,
- arm_sme::OuterProductOp op, VectorType resultType,
- VectorType inputType) const {
- if (op.getResultType() != resultType)
- return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
- diag << "unsupported result type, expected " << resultType;
- });
-
- auto lhsDefOp = op.getLhs().getDefiningOp<ExtOp>();
- auto rhsDefOp = op.getRhs().getDefiningOp<ExtOp>();
-
- if (!lhsDefOp || !rhsDefOp)
- return rewriter.notifyMatchFailure(
- op, "defining op of outerproduct operands must be one of: "
- "'arith.extf' or 'arith.extsi' or 'arith.extui'");
+// Fuse four 'arm_sme.outerproduct' operations that are chained via the
+// accumulator into 4-way outer product operation.
+class OuterProductFusion4Way
+ : public OpRewritePattern<arm_sme::OuterProductOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
- auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType());
- auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType());
+ LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<arm_sme::OuterProductOp, 4> outerProductChain;
+ outerProductChain.push_back(op);
+
+ for (int i = 0; i < 3; ++i) {
+ auto currentOp = outerProductChain.back();
+ auto acc = currentOp.getAcc();
+ if (!acc)
+ return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+ auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>();
+ if (!previousOp)
+ return rewriter.notifyMatchFailure(
+ op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+ if (!previousOp->hasOneUse())
+ return rewriter.notifyMatchFailure(
+ op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
+ if (previousOp.getKind() != currentOp.getKind())
+ return rewriter.notifyMatchFailure(
+ op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
+ if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask()))
+ return rewriter.notifyMatchFailure(
+ op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
+ outerProductChain.push_back(previousOp);
+ }
- if (lhsInType != inputType || rhsInType != inputType)
- return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
- diag << "unsupported input type, expected " << inputType;
- });
+ if (failed(canFuseOuterProducts(rewriter, outerProductChain)))
+ return failure();
+
+ arm_sme::OuterProductOp op1 = outerProductChain[3];
+ arm_sme::OuterProductOp op2 = outerProductChain[2];
+ arm_sme::OuterProductOp op3 = outerProductChain[1];
+ arm_sme::OuterProductOp op4 = outerProductChain[0];
+
+ auto loc = op.getLoc();
+
+ auto packInputs = [&](Value lhs, Value rhs) {
+ auto inputType = cast<VectorType>(lhs.getType());
+ VectorType inputTypeX2 =
+ VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
+ return rewriter.create<LLVM::experimental_vector_interleave2>(
+ loc, inputTypeX2, lhs, rhs);
+ };
+
+ auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
+ op3.getLhs().getDefiningOp()->getOperand(0));
+ auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0),
+ op4.getLhs().getDefiningOp()->getOperand(0));
+ auto lhs = packInputs(lhs0, lhs1);
+
+ auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
+ op3.getRhs().getDefiningOp()->getOperand(0));
+ auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0),
+ op4.getRhs().getDefiningOp()->getOperand(0));
+ auto rhs = packInputs(rhs0, rhs1);
+
+ Value lhsMask, rhsMask;
+ if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() ||
+ op4.getLhsMask()) {
+ auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
+ auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
+ lhsMask = packInputs(lhs0Mask, lhs1Mask);
+
+ auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask());
+ auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask());
+ rhsMask = packInputs(rhs0Mask, rhs1Mask);
+ }
+
+ auto lhsExtOp = op.getLhs().getDefiningOp();
+ auto rhsExtOp = op.getRhs().getDefiningOp();
+
+ arm_sme::CombiningKind kind = op.getKind();
+ if (kind == arm_sme::CombiningKind::Add) {
+ if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+ rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+ rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+ rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+ rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else
+ llvm_unreachable("unexpected extend op!");
+ } else if (kind == arm_sme::CombiningKind::Sub) {
+ if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+ rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+ rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+ rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+ rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else
+ llvm_unreachable("unexpected extend op!");
+ } else {
+ llvm_unreachable("unexpected arm_sme::CombiningKind!");
+ }
----------------
banach-space wrote:
Missed that! Not sure how to improve this then :/
https://github.com/llvm/llvm-project/pull/79288
More information about the Mlir-commits
mailing list