[Mlir-commits] [mlir] 64ecd76 - [mlir][x86vector] Shuffle FMAs (#172823)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 20 03:51:30 PST 2026
Author: Arun Thangamani
Date: 2026-01-20T17:21:26+05:30
New Revision: 64ecd762e9518ca1bc9f0764411313578341d028
URL: https://github.com/llvm/llvm-project/commit/64ecd762e9518ca1bc9f0764411313578341d028
DIFF: https://github.com/llvm/llvm-project/commit/64ecd762e9518ca1bc9f0764411313578341d028.diff
LOG: [mlir][x86vector] Shuffle FMAs (#172823)
This patch shuffles FMAs whose operands are produced by x86vector operations
so that FMAs are grouped with respect to the odd/even packed index.
This continues PR https://github.com/llvm/llvm-project/pull/170267, with the
goal of managing register allocation more efficiently.
Added:
mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
Modified:
mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
mlir/include/mlir/Dialect/X86Vector/Transforms.h
mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 3c73eadf82167..891829fca017f 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -60,6 +60,17 @@ def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.shuffle_vector_fma_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns to shuffle `vector.fma` ops whose operands are produced
+    by x86vector operations, such that FMAs are grouped with respect to the
+    odd/even packed index.
+
+    An FMA qualifies when one of its lhs/rhs is produced by
+    `x86vector.avx.cvt.packed.even.indexed_to_f32`, both lhs and rhs are
+    produced by single-use `x86vector.avx.cvt.packed.even.indexed_to_f32` or
+    `x86vector.avx.bcst_to_f32.packed` ops in the same block, and the FMA has
+    a single consumer in that block. Qualifying FMAs are moved, together with
+    their operand producers, right before their consumer (or before the user
+    of a single-use `vector.shape_cast` consumer). The intent is to help
+    register allocation.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
#endif // X86VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index c25cdaf2d9428..aadca92708908 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -100,6 +100,10 @@ void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
// range by placing them at their earliest legal use site.
void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
+// Shuffles FMAs with x86vector operations as operands such that FMAs are
+// grouped with respect to odd/even packed index.
+// Only a vector.fma whose lhs/rhs come from single-use
+// x86vector.avx.cvt.packed.even.indexed_to_f32 /
+// x86vector.avx.bcst_to_f32.packed ops (at least one of them the even-indexed
+// conversion), and whose single consumer is in the same block, is moved.
+void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index e77d30c9c5ffb..c6be69305da50 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -42,6 +42,11 @@ void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
x86vector::populateSinkVectorProducerOpsPatterns(patterns);
}
+void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  // The transform op is a thin descriptor: delegate to the x86vector entry
+  // point that registers the FMA shuffling pattern.
+  ::mlir::x86vector::populateShuffleVectorFMAOpsPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index bbd9be880eb0a..01d2ec4810e29 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
VectorContractToPackedTypeDotProduct.cpp
VectorContractBF16ToFMA.cpp
SinkVectorProducerOps.cpp
+ ShuffleVectorFMAOps.cpp
LINK_LIBS PUBLIC
MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
new file mode 100644
index 0000000000000..a66546a5d1e45
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
@@ -0,0 +1,186 @@
+//===- ShuffleVectorFMAOps.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+namespace {
+
+// Returns true if `operand` is produced by a CvtPackedEvenIndexedToF32Op or a
+// BcstToPackedF32Op and has exactly one use. Values with any other producer,
+// or with no producer at all (block arguments), are rejected.
+static bool validateFMAOperands(Value operand) {
+  Operation *producer = operand.getDefiningOp();
+  bool isSupportedProducer =
+      isa_and_nonnull<x86vector::CvtPackedEvenIndexedToF32Op,
+                      x86vector::BcstToPackedF32Op>(producer);
+  // Both accepted ops are single-result, so the operand *is* their result.
+  return isSupportedProducer && operand.hasOneUse();
+}
+
+// Returns true if `fmaOp` may be shuffled, i.e. all of the following hold:
+//   (i)   its lhs or rhs is produced by a CvtPackedEvenIndexedToF32Op,
+//   (ii)  both lhs and rhs are produced by a CvtPackedEvenIndexedToF32Op or a
+//         BcstToPackedF32Op whose result has exactly one use,
+//   (iii) those producers are in the same block as `fmaOp`, and
+//   (iv)  the FMA result has exactly one user, located in that same block.
+static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
+  Value lhs = fmaOp.getLhs();
+  Value rhs = fmaOp.getRhs();
+
+  // lhs/rhs may be block arguments without a defining op; plain `isa<>` would
+  // assert on the resulting null pointer, so use the null-safe variant.
+  if (!isa_and_nonnull<x86vector::CvtPackedEvenIndexedToF32Op>(
+          lhs.getDefiningOp()) &&
+      !isa_and_nonnull<x86vector::CvtPackedEvenIndexedToF32Op>(
+          rhs.getDefiningOp()))
+    return false;
+
+  // After this check both operands are known to have a defining op.
+  if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs))
+    return false;
+
+  Block *block = fmaOp->getBlock();
+  if (lhs.getDefiningOp()->getBlock() != block ||
+      rhs.getDefiningOp()->getBlock() != block)
+    return false;
+
+  if (!fmaOp.getResult().hasOneUse())
+    return false;
+
+  Operation *consumer = *fmaOp.getResult().getUsers().begin();
+  return consumer->getBlock() == block;
+}
+
+// Relocates `fmaOp`, preceded by the ops producing its lhs and rhs, to just
+// before the FMA's consumer. When that consumer is a vector.shape_cast whose
+// single user lives in the same block, the shape_cast travels along and the
+// whole group is placed just before the shape_cast's user instead.
+// Assumes `fmaOp` has exactly one use (see validateVectorFMAOp).
+// TODO: Move before first consumer, if there are multiple.
+static void moveFMA(PatternRewriter &rewriter, vector::FMAOp fmaOp) {
+  Operation *insertionPoint = *fmaOp.getResult().getUsers().begin();
+
+  // Decide whether to hop over a single-use shape_cast consumer.
+  vector::ShapeCastOp trailingCast;
+  if (auto castOp = dyn_cast<vector::ShapeCastOp>(insertionPoint)) {
+    Value castResult = castOp.getResult();
+    if (castResult.hasOneUse()) {
+      Operation *castUser = *castResult.getUsers().begin();
+      if (castUser->getBlock() == fmaOp->getBlock()) {
+        trailingCast = castOp;
+        insertionPoint = castUser;
+      }
+    }
+  }
+
+  // Order matters: producers first, then the FMA, then (optionally) the cast.
+  for (Operation *op : {fmaOp.getLhs().getDefiningOp(),
+                        fmaOp.getRhs().getDefiningOp(), fmaOp.getOperation()})
+    rewriter.moveOpBefore(op, insertionPoint);
+
+  if (trailingCast)
+    rewriter.moveOpBefore(trailingCast.getOperation(), insertionPoint);
+}
+
+// Shuffle FMAs with x86vector operations as operands such that
+// FMAs are grouped with respect to odd/even packed index.
+//
+// Starting from an FMA that satisfies validateVectorFMAOp, the block is
+// scanned forward. If the very next FMA also consumes an even-indexed
+// conversion, the FMAs are treated as already grouped and nothing is done.
+// Otherwise every later valid FMA, and finally the matched FMA, is moved
+// (with its operand producers) right before its consumer.
+//
+// For example:
+// ```
+// %1 = x86vector.avx.bcst_to_f32.packed
+// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+// %3 = vector.fma %1, %2, %arg1
+// %4 = x86vector.avx.bcst_to_f32.packed
+// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
+// %6 = vector.fma %4, %5, %3
+// %7 = x86vector.avx.bcst_to_f32.packed
+// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+// %9 = vector.fma %7, %8, %arg2
+// %10 = x86vector.avx.bcst_to_f32.packed
+// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
+// %12 = vector.fma %10, %11, %9
+// yield %6, %12
+// ```
+// to
+// ```
+// %1 = x86vector.avx.bcst_to_f32.packed
+// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+// %3 = vector.fma %1, %2, %arg1
+// %7 = x86vector.avx.bcst_to_f32.packed
+// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+// %9 = vector.fma %7, %8, %arg2
+// %10 = x86vector.avx.bcst_to_f32.packed
+// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
+// %12 = vector.fma %10, %11, %9
+// %4 = x86vector.avx.bcst_to_f32.packed
+// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
+// %6 = vector.fma %4, %5, %3
+// yield %6, %12
+// ```
+// TODO: Shuffling supported only if the FMA, lhs/rhs defining operations
+// have only one consumer. Have to extend this pass for multiple consumers.
+// NOTE(review): the moved producers read from memrefs; nothing here checks
+// for intervening memory writes between their old and new positions. Confirm
+// this is safe for the intended inputs or add a side-effect check.
+struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
+  using OpRewritePattern<vector::FMAOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::FMAOp fmaOp,
+                                PatternRewriter &rewriter) const override {
+    if (!validateVectorFMAOp(fmaOp))
+      return failure();
+
+    // Null-safe: FMA operands may be block arguments.
+    auto hasEvenIndexedCvtOperand = [](vector::FMAOp op) {
+      return isa_and_nonnull<x86vector::CvtPackedEvenIndexedToF32Op>(
+                 op.getLhs().getDefiningOp()) ||
+             isa_and_nonnull<x86vector::CvtPackedEvenIndexedToF32Op>(
+                 op.getRhs().getDefiningOp());
+    };
+
+    llvm::SmallVector<vector::FMAOp> fmaOps;
+    bool isImmediateNextFMA = true;
+    for (Operation *nextOp = fmaOp->getNextNode(); nextOp;
+         nextOp = nextOp->getNextNode()) {
+      auto fma = dyn_cast<vector::FMAOp>(nextOp);
+      if (!fma)
+        continue;
+
+      // If the immediately following FMA already consumes an even-indexed
+      // conversion, the even FMAs are grouped: stop so that the match fails
+      // below and the greedy driver reaches a fixed point.
+      if (isImmediateNextFMA && hasEvenIndexedCvtOperand(fma))
+        break;
+      isImmediateNextFMA = false;
+
+      if (validateVectorFMAOp(fma))
+        fmaOps.push_back(fma);
+    }
+
+    if (fmaOps.empty())
+      return rewriter.notifyMatchFailure(
+          fmaOp, "No eligible FMA operations were found: the operation may "
+                 "already be shuffled, there may be no following FMAs, or the "
+                 "following FMAs do not satisfy the shuffle conditions.");
+
+    // Move the later FMAs first and the matched FMA last, each right before
+    // its consumer.
+    fmaOps.push_back(fmaOp);
+    for (vector::FMAOp candidate : fmaOps)
+      moveFMA(rewriter, candidate);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void x86vector::populateShuffleVectorFMAOpsPatterns(
+    RewritePatternSet &patterns) {
+  // Register the FMA shuffling pattern in the caller's pattern set.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ShuffleVectorFMAOps>(context);
+}
diff --git a/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
new file mode 100644
index 0000000000000..4bf930b51c0c2
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
@@ -0,0 +1,312 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+// Two chains, each an odd-indexed FMA feeding an even-indexed FMA through the
+// accumulator (%2 -> %5 and %8 -> %11). In the even FMAs the even-indexed
+// conversion is the rhs operand.
+func.func @shuffle_fma_with_rhs_as_even.index_to_f32(
+ %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+ %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
+{
+ %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+ %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+ %2 = vector.fma %0, %1, %arg6 : !vec
+ %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+ %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+ %5 = vector.fma %3, %4, %2 : !vec
+ %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+ %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+ %8 = vector.fma %6, %7, %arg6 : !vec
+ %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+ %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+ %11 = vector.fma %9, %10, %8 : !vec
+ // Sole consumer of %5 and %11; not itself a candidate because its
+ // operands are FMAs rather than x86vector ops.
+ %12 = vector.fma %5, %11, %arg6 : !vec
+ return %12 : !vec
+}
+
+// Groups FMAs with respect to even/odd indexed input operands.
+// The vector.fma at %5 is moved along with its operands after %11, i.e. right
+// before its consumer %12.
+// CHECK-LABEL: @shuffle_fma_with_rhs_as_even.index_to_f32
+// Odd-Indexed FMAs
+// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
+// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
+// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
+// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
+// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
+// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
+// Even-Indexed FMAs
+// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
+// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
+// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
+// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
+// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
+// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD0]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+// Same shape as the rhs variant, except that %5 takes the even-indexed
+// conversion as its lhs operand.
+func.func @shuffle_fma_with_lhs_as_even.index_to_f32(
+ %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+ %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
+{
+ %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+ %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+ %2 = vector.fma %0, %1, %arg6 : !vec
+ %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+ %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+ // Even-indexed conversion used as lhs.
+ %5 = vector.fma %4, %3, %2 : !vec
+ %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+ %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+ %8 = vector.fma %6, %7, %arg6 : !vec
+ %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+ %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+ %11 = vector.fma %9, %10, %8 : !vec
+ %12 = vector.fma %5, %11, %arg6 : !vec
+ return %12 : !vec
+}
+
+// The vector.fma at %5 is moved along with its operands after %11, i.e. right
+// before its consumer %12.
+// CHECK-LABEL: @shuffle_fma_with_lhs_as_even.index_to_f32
+// Odd-Indexed FMAs
+// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
+// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
+// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
+// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
+// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
+// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
+// Even-Indexed FMAs
+// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
+// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
+// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
+// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
+// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
+// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[EVEN1]], %[[BCST3]], %[[FMA_ODD0]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vec = vector<8xf32>
+!vecOut = vector<1x8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+// Each even-indexed FMA is consumed by a single-use vector.shape_cast, so the
+// shape_cast is moved together with the FMA to just before the arith.addf.
+func.func @shuffle_fma_with_shape_cast(
+ %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+ %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vecOut
+{
+ %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+ %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+ %2 = vector.fma %0, %1, %arg6 : !vec
+ %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+ %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+ %5 = vector.fma %3, %4, %2 : !vec
+ // Single-use reshape of %5; its user is the arith.addf below.
+ %res1 = vector.shape_cast %5 : !vec to !vecOut
+ %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+ %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+ %8 = vector.fma %6, %7, %arg6 : !vec
+ %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+ %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+ %11 = vector.fma %9, %10, %8 : !vec
+ %res2 = vector.shape_cast %11 : !vec to !vecOut
+ %12 = arith.addf %res1, %res2 : !vecOut
+ return %12 : !vecOut
+}
+
+// CHECK-LABEL: @shuffle_fma_with_shape_cast
+// Odd-Indexed FMAs
+// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
+// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
+// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
+// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
+// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
+// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
+// Even-Indexed FMAs
+// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
+// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
+// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD1]]
+// CHECK: vector.shape_cast
+// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
+// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
+// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD0]]
+// CHECK: vector.shape_cast
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+// Negative case: the broadcast %3 feeds both %5 and %8, so %5's lhs producer
+// is not single-use and %5 must not be moved.
+func.func @negative_fma_operand_has_multiple_consumer(
+ %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB,
+ %arg3: !memrefA, %arg4: !memrefB, %arg5: !vec) -> !vec
+{
+ %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+ %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+ %2 = vector.fma %0, %1, %arg5 : !vec
+ %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+ %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+ %5 = vector.fma %3, %4, %2 : !vec
+ %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg4 : !memrefB -> !vec
+ // Second use of %3.
+ %8 = vector.fma %3, %7, %arg5 : !vec
+ %9 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+ %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg4 : !memrefB -> !vec
+ %11 = vector.fma %9, %10, %8 : !vec
+ %12 = vector.fma %5, %11, %arg5 : !vec
+ return %12 : !vec
+}
+
+// The vector.fma at %5 uses %3 as its LHS operand, which has two consumers
+// (%5 and %8); therefore, the rewrite is not applied.
+// CHECK-LABEL: @negative_fma_operand_has_multiple_consumer
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+// Negative case: the even-indexed FMA %5 has two users (%8 as accumulator and
+// %12), so it must not be moved.
+func.func @negative_fma_has_multiple_consumer(
+ %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+ %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
+{
+ %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+ %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+ %2 = vector.fma %0, %1, %arg6 : !vec
+ %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+ %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+ %5 = vector.fma %3, %4, %2 : !vec
+ %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+ %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+ // First use of %5 (as accumulator).
+ %8 = vector.fma %6, %7, %5 : !vec
+ %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+ %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+ %11 = vector.fma %9, %10, %8 : !vec
+ // Second use of %5.
+ %12 = vector.fma %5, %11, %arg6 : !vec
+ return %12 : !vec
+}
+
+// The vector.fma at %5 has two uses (%8 and %12); therefore the rewrite is not applied.
+// CHECK-LABEL: @negative_fma_has_multiple_consumer
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+// Negative case: the consumers of the even-indexed FMA %5 live inside the
+// scf.if regions, i.e. in a different block than %5, so %5 must not be moved.
+func.func @negative_no_shuffle_outside_block(
+ %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+ %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec, %arg7: i1) -> !vec
+{
+ %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+ %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+ %2 = vector.fma %0, %1, %arg6 : !vec
+ %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+ %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+ %5 = vector.fma %3, %4, %2 : !vec
+
+ // Both regions use %5 across the block boundary.
+ %loop = scf.if %arg7 -> (vector<8xf32>) {
+ %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+ %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+ %8 = vector.fma %6, %7, %arg6 : !vec
+ %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+ %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+ %11 = vector.fma %9, %10, %8 : !vec
+ %12 = vector.fma %5, %11, %arg6 : !vec
+ scf.yield %12 : vector<8xf32>
+ } else {
+ %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+ %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+ %8 = vector.fma %6, %7, %arg6 : !vec
+ %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+ %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+ %11 = vector.fma %9, %10, %8 : !vec
+ %12 = vector.fma %5, %11, %arg6 : !vec
+ scf.yield %12 : vector<8xf32>
+ }
+
+ return %loop : !vec
+}
+
+// The vector.fma at %5 has its consumer (%12) in another block; therefore the
+// rewrite is not applied.
+// CHECK-LABEL: @negative_no_shuffle_outside_block
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: scf.if
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
More information about the Mlir-commits
mailing list