[Mlir-commits] [mlir] [mlir][ArmSME] Support 4-way widening outer products (PR #79288)

Cullen Rhodes llvmlistbot at llvm.org
Fri Feb 2 08:50:25 PST 2024


================
@@ -225,37 +271,238 @@ class OuterProductFusion2Way
 
     return success();
   }
+};
+
+// Fuse four 'arm_sme.outerproduct' operations that are chained via the
+// accumulator into a single 4-way outer product operation.
+class OuterProductFusion4Way
+    : public OpRewritePattern<arm_sme::OuterProductOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
 
-  // An outer product is compatible if all of the following are true:
-  // - the result type matches `resultType`.
-  // - the defining operations of the inputs are identical and of the type
-  //   `ExtOp`.
-  // - the input types of the defining operations are identical and match
-  //   `inputType`.
-  template <typename ExtOp>
-  LogicalResult isCompatible(PatternRewriter &rewriter,
-                             arm_sme::OuterProductOp op, VectorType resultType,
-                             VectorType inputType) const {
-    if (op.getResultType() != resultType)
-      return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
-        diag << "unsupported result type, expected " << resultType;
-      });
-
-    auto lhsDefOp = op.getLhs().getDefiningOp<ExtOp>();
-    auto rhsDefOp = op.getRhs().getDefiningOp<ExtOp>();
-
-    if (!lhsDefOp || !rhsDefOp)
+  LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
+                                PatternRewriter &rewriter) const override {
+    Value acc = op.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+
+    arm_sme::OuterProductOp op4 = op;
+    arm_sme::OuterProductOp op3 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    if (!op3)
       return rewriter.notifyMatchFailure(
-          op, "defining op of outerproduct operands must be one of: "
-              "'arith.extf' or 'arith.extsi' or 'arith.extui'");
+          op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+
+    acc = op3.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+
+    arm_sme::OuterProductOp op2 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    if (!op2)
+      return rewriter.notifyMatchFailure(
+          op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+
+    acc = op2.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+
+    arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    if (!op1)
+      return rewriter.notifyMatchFailure(
+          op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+
+    arm_sme::CombiningKind kind = op1.getKind();
+    if (op2.getKind() != kind || op3.getKind() != kind || op4.getKind() != kind)
+      return rewriter.notifyMatchFailure(
+          op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
+
+    if (!op1->hasOneUse() || !op2->hasOneUse() || !op3->hasOneUse())
+      return rewriter.notifyMatchFailure(
+          op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
+
+    if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()) !=
+        bool(op3.getLhsMask()) != bool(op4.getLhsMask()))
+      return rewriter.notifyMatchFailure(op,
+                                         MATCH_FAILURE_INCONSISTENT_MASKING);
----------------
c-rhodes wrote:

this is a lot cleaner, cheers!

https://github.com/llvm/llvm-project/pull/79288


More information about the Mlir-commits mailing list