[Mlir-commits] [mlir] [mlir][ArmSME] Support 2-way widening outer products (PR #78975)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jan 24 09:50:55 PST 2024


================
@@ -0,0 +1,238 @@
+//===- OuterProductWidening.cpp - Widen 'arm_sme.outerproduct' ops --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrites that fold 'arm_sme.outerproduct' operations
+// into the 2-way or 4-way widening outerproduct operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "arm-sme-outerproduct-widening"
+
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_OUTERPRODUCTWIDENING
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sme
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+// Fold two 'arm_sme.outerproduct' operations that are chained via the
+// accumulator into a 2-way outer product operation.
+//
+// For example:
+//
+//  %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+//  %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+//  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>,
+//                                               vector<[4]xf32>
+//
+//  %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+//  %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+//  %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>,
+//                                                   vector<[4]xf32>
+//
+// Becomes:
+//
+//  %a_packed = arm_sve.zip %a0, %a1 : vector<[8]xf16> to vector<[8]xf16>
+//  %b_packed = arm_sve.zip %b0, %b1 : vector<[8]xf16> to vector<[8]xf16>
+//  %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>,
+//                                                      vector<[4]xf32>
+class OuterProduct2WayWidening
+    : public OpRewritePattern<arm_sme::OuterProductOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
+                                PatternRewriter &rewriter) const override {
+    Value acc = op.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, "no accumulator operand");
+
+    arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    arm_sme::OuterProductOp op2 = op;
+    if (!op1)
+      return rewriter.notifyMatchFailure(op,
+                                         "defining op of accumulator operand "
+                                         "must be an 'arm_sme.outerproduct'");
+
+    if (op1.getKind() != op2.getKind())
+      return rewriter.notifyMatchFailure(
+          op, "combining kind (add or sub) of outer products must match");
+
+    if (!llvm::hasSingleElement(op1->getUses())) {
+      // We could still widen, but if the first outer product has an
+      // accumulator it will be used as the root for tile allocation and since
+      // the widening outer product uses the same accumulator it will get
+      // assigned the same tile ID, resulting in 3 outer products and incorrect
+      // results. No accumulator would be ok, but it's simpler to prevent this
+      // altogether, since it has no benefit.
----------------
banach-space wrote:

I think I know what you mean, but an example might simplify this. Something like (the types don't really matter for this discussion):
```mlir
// BEFORE
%a0_ext = arith.extf %a0
%b0_ext = arith.extf %b0
%0 = arm_sme.outerproduct %a0_ext, %b0_ext

%a1_ext = arith.extf %a1
%b1_ext = arith.extf %b1
%1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0

%a2_ext = arith.extf %a2
%b2_ext = arith.extf %b2
%2 = arm_sme.outerproduct %a2_ext, %b2_ext, %0

// AFTER
%a_zip = arm_sve.zip %a0, %a1
%b_zip = arm_sve.zip %b0, %b1
%1 = arm_sme.outerproduct_2way %a_zip, %b_zip

%a2_ext = arith.extf %a2
%b2_ext = arith.extf %b2
%2 = arm_sme.outerproduct %a2_ext, %b2_ext, %1
```
Now, what exactly would go wrong in this case? :) 

Also: 
```suggestion
      // We could still widen, but since the first outer product has an
      // accumulator it will be used as the root for tile allocation and since
      // the widening outer product uses the same accumulator it will get
      // assigned the same tile ID, resulting in 3 outer products and incorrect
      // results. No accumulator would be ok, but it's simpler to prevent this
      // altogether, since it has no benefit.
```

The current wording reads to me as "if it happens that the first outer product has an acc", whereas it's more like "the first OP has an accumulator, hence". #english-not-my-first-lang :)
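
To make the "3 outer products" concern concrete, here is a toy C++ model of one possible reading of the comment under discussion. It is a sketch, not code from the PR: `Tile` is a single scalar standing in for an SME tile, and it assumes that every op sharing the same accumulator root updates that one tile in place. All helper names (`outerProduct`, `outerProduct2Way`, `fusedWhileFirstOpStillLive`, and so on) are hypothetical.

```cpp
#include <cassert>

// Toy scalar stand-in for an SME tile. Assumption: ops that share an
// accumulator root are assigned the same tile and update it in place.
struct Tile {
  double value = 0.0;
};

// Hypothetical model of 'arm_sme.outerproduct' with an accumulator.
void outerProduct(Tile &tile, double a, double b) { tile.value += a * b; }

// Hypothetical model of a 2-way widening outer product: it accumulates
// both products in a single step.
void outerProduct2Way(Tile &tile, double a0, double b0, double a1,
                      double b1) {
  tile.value += a0 * b0 + a1 * b1;
}

// What the original chain computes for %1: acc + a0*b0 + a1*b1.
double expectedChain(double acc, double a0, double b0, double a1, double b1) {
  return acc + a0 * b0 + a1 * b1;
}

// Unsafe case: the first op has another user, so it is kept alive, and the
// fused op reuses the same accumulator (and therefore the same tile).
double fusedWhileFirstOpStillLive(double acc, double a0, double b0, double a1,
                                  double b1) {
  Tile tile{acc};
  outerProduct(tile, a0, b0);                // original %0, still needed
  outerProduct2Way(tile, a0, b0, a1, b1);    // fused replacement for %1
  return tile.value;                         // a0*b0 is counted twice
}

// Safe case: the first op had a single use, so it is erased after fusing.
double fusedAfterErasingFirstOp(double acc, double a0, double b0, double a1,
                                double b1) {
  Tile tile{acc};
  outerProduct2Way(tile, a0, b0, a1, b1);
  return tile.value;
}
```

With `acc = 1`, `a0 = 2`, `b0 = 3`, `a1 = 4`, `b1 = 5`, the original chain and the safe fusion both give 27, while the unsafe variant gives 33 because `a0 * b0` is accumulated twice. Under this reading, the single-use check is what rules out the unsafe variant.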

https://github.com/llvm/llvm-project/pull/78975

