[Mlir-commits] [mlir] [mlir][ArmSME] Support 4-way widening outer products (PR #79288)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Feb 2 01:54:22 PST 2024


================
@@ -225,37 +271,238 @@ class OuterProductFusion2Way
 
     return success();
   }
+};
+
+// Fuse four 'arm_sme.outerproduct' operations that are chained via the
+// accumulator into a 4-way outer product operation.
+class OuterProductFusion4Way
+    : public OpRewritePattern<arm_sme::OuterProductOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
 
-  // An outer product is compatible if all of the following are true:
-  // - the result type matches `resultType`.
-  // - the defining operations of the inputs are identical and of the type
-  //   `ExtOp`.
-  // - the input types of the defining operations are identical and match
-  //   `inputType`.
-  template <typename ExtOp>
-  LogicalResult isCompatible(PatternRewriter &rewriter,
-                             arm_sme::OuterProductOp op, VectorType resultType,
-                             VectorType inputType) const {
-    if (op.getResultType() != resultType)
-      return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
-        diag << "unsupported result type, expected " << resultType;
-      });
-
-    auto lhsDefOp = op.getLhs().getDefiningOp<ExtOp>();
-    auto rhsDefOp = op.getRhs().getDefiningOp<ExtOp>();
-
-    if (!lhsDefOp || !rhsDefOp)
+  LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
+                                PatternRewriter &rewriter) const override {
+    Value acc = op.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+
+    arm_sme::OuterProductOp op4 = op;
+    arm_sme::OuterProductOp op3 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    if (!op3)
       return rewriter.notifyMatchFailure(
-          op, "defining op of outerproduct operands must be one of: "
-              "'arith.extf' or 'arith.extsi' or 'arith.extui'");
+          op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+
+    acc = op3.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+
+    arm_sme::OuterProductOp op2 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    if (!op2)
+      return rewriter.notifyMatchFailure(
+          op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+
+    acc = op2.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+
+    arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    if (!op1)
+      return rewriter.notifyMatchFailure(
+          op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+
+    arm_sme::CombiningKind kind = op1.getKind();
+    if (op2.getKind() != kind || op3.getKind() != kind || op4.getKind() != kind)
+      return rewriter.notifyMatchFailure(
+          op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
+
+    if (!op1->hasOneUse() || !op2->hasOneUse() || !op3->hasOneUse())
+      return rewriter.notifyMatchFailure(
+          op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
+
+    if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()) ||
+        bool(op1.getLhsMask()) != bool(op3.getLhsMask()) ||
+        bool(op1.getLhsMask()) != bool(op4.getLhsMask()))
+      return rewriter.notifyMatchFailure(op,
+                                         MATCH_FAILURE_INCONSISTENT_MASKING);
+
+    if (failed(canFuseOuterProducts(rewriter, op1, op2, op3, op4)))
+      return failure();
+
+    auto loc = op.getLoc();
+
+    auto packInputs = [&](Value lhs, Value rhs) {
+      auto inputType = cast<VectorType>(lhs.getType());
+      VectorType inputTypeX2 =
+          VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
+      return rewriter.create<LLVM::experimental_vector_interleave2>(
+          loc, inputTypeX2, lhs, rhs);
+    };
 
-    auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType());
-    auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType());
+    auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
+                           op3.getLhs().getDefiningOp()->getOperand(0));
+    auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0),
+                           op4.getLhs().getDefiningOp()->getOperand(0));
+    auto lhs = packInputs(lhs0, lhs1);
 
-    if (lhsInType != inputType || rhsInType != inputType)
-      return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
-        diag << "unsupported input type, expected " << inputType;
-      });
+    auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
+                           op3.getRhs().getDefiningOp()->getOperand(0));
+    auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0),
+                           op4.getRhs().getDefiningOp()->getOperand(0));
+    auto rhs = packInputs(rhs0, rhs1);
+
+    Value lhsMask, rhsMask;
+    if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() ||
+        op4.getLhsMask()) {
+      auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
+      auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
+      lhsMask = packInputs(lhs0Mask, lhs1Mask);
+
+      auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask());
+      auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask());
+      rhsMask = packInputs(rhs0Mask, rhs1Mask);
+    }
+
+    auto lhsExtOp = op.getLhs().getDefiningOp();
+    auto rhsExtOp = op.getRhs().getDefiningOp();
+
+    if (kind == arm_sme::CombiningKind::Add) {
+      if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else
+        llvm_unreachable("unexpected extend op!");
+    } else if (kind == arm_sme::CombiningKind::Sub) {
+      if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else
+        llvm_unreachable("unexpected extend op!");
+    } else {
+      llvm_unreachable("unexpected arm_sme::CombiningKind!");
+    }
+
+    rewriter.eraseOp(op3);
+    rewriter.eraseOp(op2);
+    rewriter.eraseOp(op1);
+
+    return success();
+  }
+
+private:
+  // Four outer products can be fused if all of the following are true:
+  // - input and result types match.
+  // - the defining operations of the inputs are identical extensions,
+  //   specifically either:
+  //     - a signed or unsigned extension for integer types.
+  //     - a floating-point extension for floating-point types.
+  // - the types and extension are supported, i.e. there's a 4-way operation
+  //   they can be fused into.
+  LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
+                                     arm_sme::OuterProductOp op1,
+                                     arm_sme::OuterProductOp op2,
+                                     arm_sme::OuterProductOp op3,
+                                     arm_sme::OuterProductOp op4) const {
+    // Supported result types.
+    auto nxnxv4i32 =
+        VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
+    auto nxnxv2i64 =
+        VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
+    // Supported input types.
+    // Note: this is before packing so these have 1/4 the number of elements
+    // of the input vector types of the 4-way operations.
+    auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
+    auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
+    if (
+        // signed, i8i8i32
+        (failed(
+             isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
+         failed(
+             isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
+         failed(
+             isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
+         failed(
+             isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
----------------
MacDue wrote:

A lambda would avoid repeating the arguments:
```c++
auto failedToMatch = [&](auto extendOp, VectorType resultType,
                         VectorType inputType) {
  using ExtendOpTy = decltype(extendOp);
  return failed(isCompatible<ExtendOpTy>(rewriter, op1, resultType, inputType)) ||
         failed(isCompatible<ExtendOpTy>(rewriter, op2, resultType, inputType)) ||
         failed(isCompatible<ExtendOpTy>(rewriter, op3, resultType, inputType)) ||
         failed(isCompatible<ExtendOpTy>(rewriter, op4, resultType, inputType));
};
```
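The lambda captures `rewriter` and `op1`..`op4` by reference, so each call site only has to supply the extension op kind and the two vector types.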

Used like this:
```c++
failedToMatch(arith::ExtUIOp{}, nxnxv2i64, nxv2i16)
```

(Passing the default-constructed `arith::ExtUIOp{}` is a trick to pass the extension op type as an implicit template parameter, since proper template lambdas require C++20.)
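For readers unfamiliar with the trick, here is a minimal, self-contained sketch of the same pattern outside of MLIR (the `ExtSI`/`ExtUI` structs and the `check` function template are invented for illustration): the lambda never reads the value of its first parameter, it only recovers that parameter's type with `decltype` and forwards it as an explicit template argument.
```c++
#include <iostream>

struct ExtSI {};
struct ExtUI {};

// Stand-in for a function template such as `isCompatible<ExtOp>`.
template <typename ExtOp>
bool check(int x) {
  return x > 0;
}

int main() {
  // Pre-C++20 "template lambda": the type is smuggled in through a value
  // parameter and recovered with decltype; the value itself is never read.
  auto checkAll = [](auto extendOp, int a, int b) {
    using ExtendOpTy = decltype(extendOp);
    return check<ExtendOpTy>(a) && check<ExtendOpTy>(b);
  };

  std::cout << checkAll(ExtSI{}, 1, 2) << '\n';  // instantiates check<ExtSI>
  std::cout << checkAll(ExtUI{}, 1, -2) << '\n'; // instantiates check<ExtUI>
}
```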

https://github.com/llvm/llvm-project/pull/79288

