[Mlir-commits] [mlir] [mlir][ArmSME] Support 2-way widening outer products (PR #78975)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue Jan 30 07:56:41 PST 2024
================
@@ -0,0 +1,292 @@
+//===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrites that fuse 'arm_sme.outerproduct' operations
+// into the 2-way or 4-way widening outerproduct operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "arm-sme-outerproduct-fusion"
+
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_OUTERPRODUCTFUSION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sme
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+// Fuse two 'arm_sme.outerproduct' operations that are chained via the
+// accumulator into a 2-way outer product operation.
+//
+// For example:
+//
+// %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+// %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+// %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>,
+// vector<[4]xf32>
+//
+// %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+// %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+// %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>,
+// vector<[4]xf32>
+//
+// Becomes:
+//
+// %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1)
+// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+// %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1)
+// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+// %0 = arm_sme.fmopa_2way %a_packed, %b_packed
+// : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+class OuterProductFusion2Way
+ : public OpRewritePattern<arm_sme::OuterProductOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
+ PatternRewriter &rewriter) const override {
+ Value acc = op.getAcc();
+ if (!acc)
+ return rewriter.notifyMatchFailure(op, "no accumulator operand");
+
+ arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+ arm_sme::OuterProductOp op2 = op;
+ if (!op1)
+ return rewriter.notifyMatchFailure(op,
+ "defining op of accumulator operand "
+ "must be an 'arm_sme.outerproduct'");
+
+ if (op1.getKind() != op2.getKind())
+ return rewriter.notifyMatchFailure(
+ op, "combining kind (add or sub) of outer products must match");
+
+ if (!op1->hasOneUse()) {
+ // If the first outer product has uses other than as the input to another
+ // outer product, it can't be erased after fusion. This is a problem when
+ // it also has an accumulator as this will be used as the root for tile
+ // allocation and since the widening outer product uses the same
+ // accumulator it will get assigned the same tile ID, resulting in 3
+ // outer products accumulating to the same tile and incorrect results.
+ //
+ // Example:
+ //
+ // %acc = arith.constant dense<0.0> ; root for tile allocation
+ // %0 = arm_sme.outerproduct %a0, %b0 acc(%acc)
+ // vector.print %0 ; intermediary use, can't erase %0
+ // %1 = arm_sme.outerproduct %a1, %b1 acc(%0)
+ //
+ // After fusion and tile allocation
+ //
+ // %0 = arm_sme.zero {tile_id = 0 : i32}
+ // %1 = arm_sme.outerproduct %a0, %b0 acc(%0) {tile_id = 0 : i32}
+ // vector.print %1
+ // %2 = arm_sme.fmopa_2way %a, %b acc(%0) {tile_id = 0 : i32}
+ //
+ // No accumulator would be ok, but it's simpler to prevent this
+ // altogether, since it has no benefit.
+ return rewriter.notifyMatchFailure(
+ op, "first outer product is not single use and cannot be removed, "
+ "no benefit to fusing");
+ }
+
+ if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
+ return rewriter.notifyMatchFailure(
+ op, "unsupported masking, either both outerproducts are masked "
+ "or neither");
+
+ if (failed(canFuseOuterProducts(rewriter, op1, op2)))
+ return failure();
+
+ auto loc = op.getLoc();
+
+ auto packInputs = [&](VectorType type, Value lhs, Value rhs) {
+ return rewriter.create<LLVM::experimental_vector_interleave2>(loc, type,
+ lhs, rhs);
+ };
+
+ auto extOp = op.getLhs().getDefiningOp();
+ VectorType extSourceVectorType =
+ cast<VectorType>(extOp->getOperand(0).getType());
+ VectorType widenedVectorType =
+ VectorType::Builder(extSourceVectorType)
+ .setDim(0, extSourceVectorType.getShape()[0] * 2);
----------------
MacDue wrote:
'widened' here threw me off; I always assumed 'widened' referred to the element bit width.
https://github.com/llvm/llvm-project/pull/78975
More information about the Mlir-commits
mailing list