[Mlir-commits] [mlir] [mlir][x86vector] Shuffle FMAs (PR #172823)
Arun Thangamani
llvmlistbot at llvm.org
Sun Dec 21 23:57:31 PST 2025
================
@@ -0,0 +1,182 @@
+//===- ShuffleVectorFMAOps.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+namespace {
+
+// Returns true when `op` is the result of one of the x86vector operations this
+// transform understands (CvtPackedEvenIndexedToF32Op or BcstToPackedF32Op) and
+// that result has exactly one use. Values produced by any other operation, and
+// block arguments (no defining op), yield false.
+static bool validateX86OpsHasOneUser(Value op) {
+  // The typed getDefiningOp<> is null-safe, so block arguments simply fail
+  // both checks.
+  bool producedBySupportedOp =
+      op.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>() ||
+      op.getDefiningOp<x86vector::BcstToPackedF32Op>();
+  return producedBySupportedOp && op.hasOneUse();
+}
+
+// Validates the vector.fma operation on the following conditions:
+// (i) the defining operation of lhs or rhs is a CvtPackedEvenIndexedToF32Op,
+// (ii) the defining operations of both lhs and rhs are supported x86vector
+// operations whose results have exactly one consumer, (iii) those operations
+// are in the same block as the FMA, and (iv) the FMA has exactly one user,
+// which is also in that block.
+static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
+  Value lhs = fmaOp.getLhs();
+  Value rhs = fmaOp.getRhs();
+
+  // lhs/rhs may be block arguments with no defining op; the typed
+  // getDefiningOp<> is a null-safe dyn_cast, whereas isa<> on a null
+  // Operation* asserts.
+  if (!lhs.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>() &&
+      !rhs.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>())
+    return false;
+
+  // Passing this check also guarantees both operands have a defining op, so
+  // the dereferences below are safe.
+  if (!validateX86OpsHasOneUser(lhs) || !validateX86OpsHasOneUser(rhs))
+    return false;
+
+  Block *fmaBlock = fmaOp->getBlock();
+  if (lhs.getDefiningOp()->getBlock() != fmaBlock ||
+      rhs.getDefiningOp()->getBlock() != fmaBlock)
+    return false;
+
+  if (!fmaOp.getResult().hasOneUse())
+    return false;
+
+  Operation *consumer = *fmaOp.getResult().getUsers().begin();
+  return consumer->getBlock() == fmaBlock;
+}
+
+// Moves the vector.fma, together with the defining operations of its lhs and
+// rhs, to just before its consumer. If that consumer is a vector.shape_cast
+// whose result has exactly one user in the same block as the FMA, the
+// shape_cast is moved as well and everything is placed before that user.
+// Expects fmaOp to have passed validateVectorFMAOp (single user, lhs/rhs
+// produced by operations).
+// TODO: Move before first consumer, if there are multiple.
+static void moveFMA(vector::FMAOp fmaOp) {
+  Operation *consumer = *fmaOp.getResult().getUsers().begin();
+
+  // Look through a single-use shape_cast so the moved chain ends up directly
+  // in front of the operation that actually consumes the FMA's value.
+  vector::ShapeCastOp shapeCastToMove;
+  if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(consumer)) {
+    if (shapeCastOp.getResult().hasOneUse()) {
+      Operation *nxtConsumer = *shapeCastOp.getResult().getUsers().begin();
+      if (nxtConsumer->getBlock() == fmaOp->getBlock()) {
+        shapeCastToMove = shapeCastOp;
+        consumer = nxtConsumer;
+      }
+    }
+  }
+
+  // Operands first, then the FMA, then the optional shape_cast, so every
+  // definition still dominates its uses after the move.
+  fmaOp.getLhs().getDefiningOp()->moveBefore(consumer);
+  fmaOp.getRhs().getDefiningOp()->moveBefore(consumer);
+  fmaOp->moveBefore(consumer);
+  if (shapeCastToMove)
+    shapeCastToMove->moveBefore(consumer);
+}
+
+// Shuffle FMAs with x86vector operations as operands such that
+// FMAs are grouped with respect to odd/even packed index.
+//
+// For example:
+// ```
+// %1 = x86vector.avx.bcst_to_f32.packed
+// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+// %3 = vector.fma %1, %2, %arg1
+// %4 = x86vector.avx.bcst_to_f32.packed
+// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
+// %6 = vector.fma %4, %5, %3
+// %7 = x86vector.avx.bcst_to_f32.packed
+// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+// %9 = vector.fma %7, %8, %arg2
+// %10 = x86vector.avx.bcst_to_f32.packed
+// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
+// %12 = vector.fma %10, %11, %9
+// yield %6, %12
+// ```
+// to
+// ```
+// %1 = x86vector.avx.bcst_to_f32.packed
+// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+// %3 = vector.fma %1, %2, %arg1
+// %4 = x86vector.avx.bcst_to_f32.packed
+// %5 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+// %6 = vector.fma %4, %5, %arg2
+// %7 = x86vector.avx.bcst_to_f32.packed
+// %8 = x86vector.avx.cvt.packed.even.indexed_to_f32
+// %9 = vector.fma %7, %8, %3
+// %10 = x86vector.avx.bcst_to_f32.packed
+// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
+// %12 = vector.fma %10, %11, %6
+// yield %9, %12
+// ```
+// TODO: Shuffling is supported only if the FMA and its lhs/rhs defining
+// operations each have exactly one consumer. Extend this pattern to handle
+// multiple consumers.
+struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
+ using OpRewritePattern<vector::FMAOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::FMAOp fmaOp,
+ PatternRewriter &rewriter) const override {
+
+ if (!validateVectorFMAOp(fmaOp))
+ return failure();
+
+ llvm::SmallVector<vector::FMAOp> fmaOps;
----------------
arun-thmn wrote:
Here we followed the logic of our earlier PR: https://github.com/llvm/llvm-project/pull/169333.
In the first pass, we capture all `fma` ops and order them. This should take O(n) in the best and average cases; the worst case is O(n^2). On subsequent applications of the pattern, the loop breaks immediately upon reaching the next FMA.
https://github.com/llvm/llvm-project/pull/172823
More information about the Mlir-commits
mailing list