[Mlir-commits] [mlir] [mlir][x86vector] Shuffle FMAs (PR #172823)

Arun Thangamani llvmlistbot at llvm.org
Tue Jan 20 01:59:59 PST 2026


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/172823

>From 027028008c5811f95d4389986e56cd27f70c12ab Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 18 Dec 2025 01:05:00 -0800
Subject: [PATCH 1/6] initial commit for shuffling fma

---
 .../TransformOps/X86VectorTransformOps.td     |  11 ++
 .../mlir/Dialect/X86Vector/Transforms.h       |   2 +
 .../TransformOps/X86VectorTransformOps.cpp    |   5 +
 .../X86Vector/Transforms/CMakeLists.txt       |   1 +
 .../Transforms/ShuffleVectorFMAOps.cpp        | 127 ++++++++++++++++++
 .../X86Vector/shuffle-vector-fmas.mlir        |  51 +++++++
 6 files changed, 197 insertions(+)
 create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
 create mode 100644 mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 12ba5e9f11141..a1c946dc3f91d 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -49,6 +49,17 @@ def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.shuffle_vector_fma_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns to sink vector producer operations forward in a block to
+         place them immediately before their first use.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 
 #endif // X86VECTOR_TRANSFORM_OPS
 
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index b9c9054f57890..93c4544a61f62 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -95,6 +95,8 @@ void populateVectorContractToPackedTypeDotProductPatterns(
 // range by placing them at their earliest legal use site
 void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
 
+void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 25772f2aa57f4..a23f9b701199e 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -37,6 +37,11 @@ void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
   x86vector::populateSinkVectorProducerOpsPatterns(patterns);
 }
 
+void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  x86vector::populateShuffleVectorFMAOpsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index cc4d3cac0f7ea..95471bc72f65d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   VectorContractToFMA.cpp
   VectorContractToPackedTypeDotProduct.cpp
   SinkVectorProducerOps.cpp
+  ShuffleVectorFMAOps.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
new file mode 100644
index 0000000000000..c3160529f1fd7
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
@@ -0,0 +1,127 @@
+//===- ShuffleVectorFMAOps.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+namespace {
+
+static bool validateX86OpsHasOneUser(Value op) {
+
+  if (auto x86Op = op.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>()) {
+    if (!x86Op.getResult().hasOneUse())
+      return false;
+  } else if (auto x86Op = op.getDefiningOp<x86vector::BcstToPackedF32Op>()) {
+    if (!x86Op.getResult().hasOneUse())
+      return false;
+  } else {
+    return false;
+  }
+  return true;
+}
+
+static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
+
+  Value lhs = fmaOp.getLhs();
+  Value rhs = fmaOp.getRhs();
+
+  if (!isa<x86vector::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
+      !isa<x86vector::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
+    return false;
+
+  if (!validateX86OpsHasOneUser(fmaOp.getLhs()) ||
+      !validateX86OpsHasOneUser(fmaOp.getRhs()))
+    return false;
+
+  if (!fmaOp.getResult().hasOneUse())
+    return false;
+
+  return true;
+}
+
+static void moveFMA(vector::FMAOp fmaOp) {
+  Operation *onlyUser = *fmaOp.getResult().getUsers().begin();
+
+  if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(onlyUser)) {
+    if (shapeCastOp.getResult().hasOneUse()) {
+      onlyUser = *shapeCastOp.getResult().getUsers().begin();
+      fmaOp.getLhs().getDefiningOp()->moveBefore(onlyUser);
+      fmaOp.getRhs().getDefiningOp()->moveBefore(onlyUser);
+      fmaOp->moveBefore(onlyUser);
+      shapeCastOp->moveBefore(onlyUser);
+      return;
+    }
+  }
+
+  fmaOp.getLhs().getDefiningOp()->moveBefore(onlyUser);
+  fmaOp.getRhs().getDefiningOp()->moveBefore(onlyUser);
+  fmaOp->moveBefore(onlyUser);
+  return;
+}
+
+struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
+  using OpRewritePattern<vector::FMAOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::FMAOp fmaOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (!validateVectorFMAOp(fmaOp))
+      return failure();
+
+    llvm::SmallVector<vector::FMAOp> fmaOps;
+    Operation *nextOp = fmaOp;
+    bool loopBreak = true;
+
+    while ((nextOp = nextOp->getNextNode())) {
+      if (auto fma = dyn_cast<vector::FMAOp>(nextOp)) {
+        if (isa<x86vector::CvtPackedEvenIndexedToF32Op>(
+                fma.getLhs().getDefiningOp()) ||
+            isa<x86vector::CvtPackedEvenIndexedToF32Op>(
+                fma.getRhs().getDefiningOp())) {
+          if (loopBreak)
+            break;
+        }
+
+        if (validateVectorFMAOp(fma))
+          fmaOps.push_back(fma);
+
+        loopBreak = false;
+      }
+    }
+
+    if (fmaOps.empty())
+      return failure();
+
+    fmaOps.push_back(fmaOp);
+    for (size_t i = 0; i < fmaOps.size(); i++) {
+      moveFMA(fmaOps[i]);
+    }
+
+    return success();
+  }
+};
+
+} // namespace
+
+void x86vector::populateShuffleVectorFMAOpsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ShuffleVectorFMAOps>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
new file mode 100644
index 0000000000000..93abecdf07914
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+func.func @shuffle_vector_fma(
+  %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+  %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
+{
+  %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+  %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %2 = vector.fma %0, %1, %arg6 : !vec
+  %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+  %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %5 = vector.fma %3, %4, %2 : !vec
+  %55 = vector.shape_cast %5 : vector<8xf32> to vector<1x8xf32>
+  %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+  %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+  %8 = vector.fma %6, %7, %arg6 : !vec
+  %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+  %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+  %11 = vector.fma %9, %10, %8 : !vec
+  %13 = vector.shape_cast %55 : vector<1x8xf32> to vector<8xf32>
+  %12 = vector.fma %13, %11, %arg6 : !vec
+  return %12 : !vec
+}
+
+// CHECK-LABEL: @shuffle_vector_fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 2061dd4fa4a68633f2e01328836479974c7e7262 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 18 Dec 2025 08:04:48 -0800
Subject: [PATCH 2/6] added comments + new test-cases

---
 .../TransformOps/X86VectorTransformOps.td     |   4 +-
 .../mlir/Dialect/X86Vector/Transforms.h       |   2 +
 .../Transforms/ShuffleVectorFMAOps.cpp        | 145 +++++++---
 .../X86Vector/shuffle-vector-fmas.mlir        | 268 +++++++++++++++++-
 4 files changed, 367 insertions(+), 52 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index a1c946dc3f91d..1183d1bb8b9f3 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -53,8 +53,8 @@ def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
     "apply_patterns.x86vector.shuffle_vector_fma_ops",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Collect patterns to sink vector producer operations forward in a block to
-         place them immediately before their first use.
+    Collect patterns to shuffle FMAs with x86vector operations as operands 
+    such that FMAs are grouped with respect to odd/even packed index.
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index 93c4544a61f62..72f3f685d7ec1 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -95,6 +95,8 @@ void populateVectorContractToPackedTypeDotProductPatterns(
 // range by placing them at their earliest legal use site
 void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
 
+// Shuffles FMAs with x86vector operations as operands such that FMAs are
+// grouped with respect to odd/even packed index.
 void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
index c3160529f1fd7..945cf4e36b044 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
@@ -7,12 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
 
 #include "mlir/Pass/Pass.h"
@@ -24,22 +21,24 @@ using namespace mlir::x86vector;
 
 namespace {
 
+// Validates whether the given operation is an x86vector operation and has only
+// one consumer.
 static bool validateX86OpsHasOneUser(Value op) {
+  if (auto cvt = op.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>())
+    return cvt.getResult().hasOneUse();
 
-  if (auto x86Op = op.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>()) {
-    if (!x86Op.getResult().hasOneUse())
-      return false;
-  } else if (auto x86Op = op.getDefiningOp<x86vector::BcstToPackedF32Op>()) {
-    if (!x86Op.getResult().hasOneUse())
-      return false;
-  } else {
-    return false;
-  }
-  return true;
+  if (auto bcst = op.getDefiningOp<x86vector::BcstToPackedF32Op>())
+    return bcst.getResult().hasOneUse();
+
+  return false;
 }
 
+// Validates the vector.fma operation on the following conditions:
+// (i) one of the lhs or rhs defining operation should be
+// CvtPackedEvenIndexedToF32Op, (ii) the lhs or rhs defining operation should be
+// an x86vector operation and has only one consumer, (iii) all oerations in same
+// block, and (iv) ths FMA has only one user.
 static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
-
   Value lhs = fmaOp.getLhs();
   Value rhs = fmaOp.getRhs();
 
@@ -47,36 +46,89 @@ static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
       !isa<x86vector::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
     return false;
 
-  if (!validateX86OpsHasOneUser(fmaOp.getLhs()) ||
-      !validateX86OpsHasOneUser(fmaOp.getRhs()))
+  if (!validateX86OpsHasOneUser(lhs) || !validateX86OpsHasOneUser(rhs))
+    return false;
+
+  if (lhs.getDefiningOp()->getBlock() != rhs.getDefiningOp()->getBlock())
+    return false;
+
+  if (lhs.getDefiningOp()->getBlock() != fmaOp->getBlock())
     return false;
 
   if (!fmaOp.getResult().hasOneUse())
     return false;
 
+  Operation *consumer = *fmaOp.getResult().getUsers().begin();
+  if (consumer->getBlock() != fmaOp->getBlock())
+    return false;
+
   return true;
 }
 
+// Moves vector.fma along with the lhs and rhs defining operation before it's
+// comsumer. If the consumer is vector.ShapeCastOp and has only one user then
+// move before the consumer of vector.ShapeCastOp.
+// TODO: Move before first consumer, if there are multiple.
 static void moveFMA(vector::FMAOp fmaOp) {
-  Operation *onlyUser = *fmaOp.getResult().getUsers().begin();
+  Operation *consumer = *fmaOp.getResult().getUsers().begin();
 
-  if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(onlyUser)) {
+  if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(consumer)) {
     if (shapeCastOp.getResult().hasOneUse()) {
-      onlyUser = *shapeCastOp.getResult().getUsers().begin();
-      fmaOp.getLhs().getDefiningOp()->moveBefore(onlyUser);
-      fmaOp.getRhs().getDefiningOp()->moveBefore(onlyUser);
-      fmaOp->moveBefore(onlyUser);
-      shapeCastOp->moveBefore(onlyUser);
-      return;
+      Operation *nxtConsumer = *shapeCastOp.getResult().getUsers().begin();
+      if (nxtConsumer->getBlock() == fmaOp->getBlock()) {
+        consumer = *shapeCastOp.getResult().getUsers().begin();
+        fmaOp.getLhs().getDefiningOp()->moveBefore(consumer);
+        fmaOp.getRhs().getDefiningOp()->moveBefore(consumer);
+        fmaOp->moveBefore(consumer);
+        shapeCastOp->moveBefore(consumer);
+        return;
+      }
     }
   }
 
-  fmaOp.getLhs().getDefiningOp()->moveBefore(onlyUser);
-  fmaOp.getRhs().getDefiningOp()->moveBefore(onlyUser);
-  fmaOp->moveBefore(onlyUser);
+  fmaOp.getLhs().getDefiningOp()->moveBefore(consumer);
+  fmaOp.getRhs().getDefiningOp()->moveBefore(consumer);
+  fmaOp->moveBefore(consumer);
   return;
 }
 
+// Shuffle FMAs with x86vector operations as operands such that
+// FMAs are grouped with respect to odd/even packed index.
+//
+// For example:
+// ```
+//   %1 = x86vector.avx.bcst_to_f32.packed
+//   %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+//   %3 = vector.fma %1, %2, %arg1
+//   %4 = x86vector.avx.bcst_to_f32.packed
+//   %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
+//   %6 = vector.fma %4, %5, %3
+//   %7 = x86vector.avx.bcst_to_f32.packed
+//   %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+//   %9 = vector.fma %7, %8, %arg2
+//   %10 = x86vector.avx.bcst_to_f32.packed
+//   %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
+//   %12 = vector.fma %10, %11, %9
+//   yield %6, %12
+// ```
+// to
+// ```
+//   %1 = x86vector.avx.bcst_to_f32.packed
+//   %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+//   %3 = vector.fma %1, %2, %arg1
+//   %4 = x86vector.avx.bcst_to_f32.packed
+//   %5 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+//   %6 = vector.fma %4, %5, %arg2
+//   %7 = x86vector.avx.bcst_to_f32.packed
+//   %8 = x86vector.avx.cvt.packed.even.indexed_to_f32
+//   %9 = vector.fma %7, %8, %3
+//   %10 = x86vector.avx.bcst_to_f32.packed
+//   %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
+//   %12 = vector.fma %10, %11, %6
+//   yield %9, %12
+// ```
+// TODO: Shuffling is supported only if the FMA and its lhs/rhs defining
+// operations each have a single consumer. Extend this to multiple consumers.
 struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
   using OpRewritePattern<vector::FMAOp>::OpRewritePattern;
 
@@ -88,32 +140,35 @@ struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
 
     llvm::SmallVector<vector::FMAOp> fmaOps;
     Operation *nextOp = fmaOp;
-    bool loopBreak = true;
+    bool stopAtNextDependentFMA = true;
 
+    // Break the loop and return failure if the immediate next FMA op
+    // has CvtPackedEvenIndexedToF32Op among its lhs/rhs defining ops.
     while ((nextOp = nextOp->getNextNode())) {
-      if (auto fma = dyn_cast<vector::FMAOp>(nextOp)) {
-        if (isa<x86vector::CvtPackedEvenIndexedToF32Op>(
-                fma.getLhs().getDefiningOp()) ||
-            isa<x86vector::CvtPackedEvenIndexedToF32Op>(
-                fma.getRhs().getDefiningOp())) {
-          if (loopBreak)
-            break;
-        }
-
-        if (validateVectorFMAOp(fma))
-          fmaOps.push_back(fma);
-
-        loopBreak = false;
-      }
+      auto fma = dyn_cast<vector::FMAOp>(nextOp);
+      if (!fma)
+        continue;
+
+      bool hasX86CvtOperand = isa<x86vector::CvtPackedEvenIndexedToF32Op>(
+                                  fma.getLhs().getDefiningOp()) ||
+                              isa<x86vector::CvtPackedEvenIndexedToF32Op>(
+                                  fma.getRhs().getDefiningOp());
+
+      if (hasX86CvtOperand && stopAtNextDependentFMA)
+        break;
+
+      if (validateVectorFMAOp(fma))
+        fmaOps.push_back(fma);
+
+      stopAtNextDependentFMA = false;
     }
 
     if (fmaOps.empty())
       return failure();
 
     fmaOps.push_back(fmaOp);
-    for (size_t i = 0; i < fmaOps.size(); i++) {
-      moveFMA(fmaOps[i]);
-    }
+    for (auto fmaOp : fmaOps)
+      moveFMA(fmaOp);
 
     return success();
   }
diff --git a/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
index 93abecdf07914..0e3c0b53b9cbd 100644
--- a/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
+++ b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
@@ -4,7 +4,7 @@
 !memrefA = memref<1x1x1xbf16>
 !memrefB = memref<1x8x2xbf16>
 
-func.func @shuffle_vector_fma(
+func.func @shuffle_fma_lhs_even_index(
   %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
   %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
 {
@@ -14,19 +14,17 @@ func.func @shuffle_vector_fma(
   %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
   %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
   %5 = vector.fma %3, %4, %2 : !vec
-  %55 = vector.shape_cast %5 : vector<8xf32> to vector<1x8xf32>
   %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
   %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
   %8 = vector.fma %6, %7, %arg6 : !vec
   %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
   %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
   %11 = vector.fma %9, %10, %8 : !vec
-  %13 = vector.shape_cast %55 : vector<1x8xf32> to vector<8xf32>
-  %12 = vector.fma %13, %11, %arg6 : !vec
+  %12 = vector.fma %5, %11, %arg6 : !vec
   return %12 : !vec
 }
 
-// CHECK-LABEL: @shuffle_vector_fma
+// CHECK-LABEL: @shuffle_fma_lhs_even_index
 // CHECK: x86vector.avx.bcst_to_f32.packed
 // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
@@ -49,3 +47,263 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+func.func @shuffle_fma_rhs_even_index(
+  %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+  %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
+{
+  %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+  %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %2 = vector.fma %0, %1, %arg6 : !vec
+  %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+  %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %5 = vector.fma %4, %3, %2 : !vec
+  %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+  %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+  %8 = vector.fma %6, %7, %arg6 : !vec
+  %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+  %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+  %11 = vector.fma %9, %10, %8 : !vec
+  %12 = vector.fma %5, %11, %arg6 : !vec
+  return %12 : !vec
+}
+
+// CHECK-LABEL: @shuffle_fma_rhs_even_index
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+func.func @negative_fma_lhs_multiple_consumer(
+  %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB,
+  %arg3: !memrefA, %arg4: !memrefB, %arg5: !vec) -> !vec
+{
+  %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+  %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %2 = vector.fma %0, %1, %arg5 : !vec
+  %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+  %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %5 = vector.fma %3, %4, %2 : !vec
+  %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg4 : !memrefB -> !vec
+  %8 = vector.fma %3, %7, %arg5 : !vec
+  %9 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+  %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg4 : !memrefB -> !vec
+  %11 = vector.fma %9, %10, %8 : !vec
+  %12 = vector.fma %5, %11, %arg5 : !vec
+  return %12 : !vec
+}
+
+// CHECK-LABEL: @negative_fma_lhs_multiple_consumer
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+func.func @negative_fma_rhs_multiple_consumer(
+  %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+  %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
+{
+  %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+  %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %2 = vector.fma %0, %1, %arg6 : !vec
+  %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+  %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %5 = vector.fma %3, %4, %2 : !vec
+  %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+  %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+  %8 = vector.fma %6, %7, %arg6 : !vec
+  %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+  %10 = vector.fma %9, %4, %8 : !vec
+  %11 = vector.fma %5, %10, %arg6 : !vec
+  return %11 : !vec
+}
+
+// CHECK-LABEL: @negative_fma_rhs_multiple_consumer
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+func.func @negative_fma_multiple_consumer(
+  %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+  %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
+{
+  %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+  %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %2 = vector.fma %0, %1, %arg6 : !vec
+  %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+  %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %5 = vector.fma %3, %4, %2 : !vec
+  %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+  %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+  %8 = vector.fma %6, %7, %5 : !vec
+  %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+  %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+  %11 = vector.fma %9, %10, %8 : !vec
+  %12 = vector.fma %5, %11, %arg6 : !vec
+  return %12 : !vec
+}
+
+// CHECK-LABEL: @negative_fma_multiple_consumer
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+func.func @negative_no_shuffle_outside_block(
+  %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+  %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec, %arg7: i1) -> !vec
+{
+  %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+  %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %2 = vector.fma %0, %1, %arg6 : !vec
+  %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+  %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %5 = vector.fma %3, %4, %2 : !vec
+
+  %loop = scf.if %arg7 -> (vector<8xf32>) {
+    %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+    %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+    %8 = vector.fma %6, %7, %arg6 : !vec
+    %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+    %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+    %11 = vector.fma %9, %10, %8 : !vec
+    %12 = vector.fma %5, %11, %arg6 : !vec
+    scf.yield %12 : vector<8xf32>
+  } else {
+    %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+    %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+    %8 = vector.fma %6, %7, %arg6 : !vec
+    %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+    %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+    %11 = vector.fma %9, %10, %8 : !vec
+    %12 = vector.fma %5, %11, %arg6 : !vec
+    scf.yield %12 : vector<8xf32>
+  }
+
+  return %loop : !vec
+}
+
+// CHECK-LABEL: @negative_no_shuffle_outside_block
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: scf.if
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 9a35d3b12c4b65a40d4b93665f238e82635c37b9 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 21 Dec 2025 23:42:36 -0800
Subject: [PATCH 3/6] simplified the test-cases

---
 .../Transforms/ShuffleVectorFMAOps.cpp        |   6 +-
 .../X86Vector/shuffle-vector-fmas.mlir        | 142 ++++++------------
 2 files changed, 47 insertions(+), 101 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
index 945cf4e36b044..90109d03d785e 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
@@ -36,8 +36,8 @@ static bool validateX86OpsHasOneUser(Value op) {
 // Validates the vector.fma operation on the following conditions:
 // (i) one of the lhs or rhs defining operation should be
 // CvtPackedEvenIndexedToF32Op, (ii) the lhs or rhs defining operation should be
-// an x86vector operation and has only one consumer, (iii) all oerations in same
-// block, and (iv) ths FMA has only one user.
+// an x86vector operation and has only one consumer, (iii) all operations
+// are in the same block, and (iv) the FMA has only one user.
 static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
   Value lhs = fmaOp.getLhs();
   Value rhs = fmaOp.getRhs();
@@ -65,7 +65,7 @@ static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
   return true;
 }
 
-// Moves vector.fma along with the lhs and rhs defining operation before it's
+// Moves vector.fma along with the lhs and rhs defining operation before its
 // comsumer. If the consumer is vector.ShapeCastOp and has only one user then
 // move before the consumer of vector.ShapeCastOp.
 // TODO: Move before first consumer, if there are multiple.
diff --git a/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
index 0e3c0b53b9cbd..52bd148353b4c 100644
--- a/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
+++ b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
@@ -4,7 +4,7 @@
 !memrefA = memref<1x1x1xbf16>
 !memrefB = memref<1x8x2xbf16>
 
-func.func @shuffle_fma_lhs_even_index(
+func.func @shuffle_fma_with_rhs_as_even.index_to_f32(
   %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
   %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
 {
@@ -24,19 +24,23 @@ func.func @shuffle_fma_lhs_even_index(
   return %12 : !vec
 }
 
-// CHECK-LABEL: @shuffle_fma_lhs_even_index
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
-// CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
-// CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
-// CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
-// CHECK: vector.fma
+// Groups FMAs with respect to even/odd indexed input operands.
+// The vector.fma at %5 is moved along with its operands after %8.  
+// CHECK-LABEL: @shuffle_fma_with_rhs_as_even.index_to_f32
+// Odd-Indexed FMAs
+// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
+// CHECK: %[[ODD0:.*]]  = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
+// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
+// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
+// CHECK: %[[ODD1:.*]]  = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
+// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
+// Even-Indexed FMAs
+// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
+// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
+// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
+// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
+// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
+// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD0]]
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -54,7 +58,7 @@ module attributes {transform.with_named_sequence} {
 !memrefA = memref<1x1x1xbf16>
 !memrefB = memref<1x8x2xbf16>
 
-func.func @shuffle_fma_rhs_even_index(
+func.func @shuffle_fma_with_lhs_as_even.index_to_f32(
   %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
   %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
 {
@@ -74,19 +78,22 @@ func.func @shuffle_fma_rhs_even_index(
   return %12 : !vec
 }
 
-// CHECK-LABEL: @shuffle_fma_rhs_even_index
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
-// CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
-// CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
-// CHECK: vector.fma
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: vector.fma
+// The vector.fma at %5 is moved along with its operands after %8.
+// CHECK-LABEL: @shuffle_fma_with_lhs_as_even.index_to_f32
+// Odd-Indexed FMAs
+// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
+// CHECK: %[[ODD0:.*]]  = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
+// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
+// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
+// CHECK: %[[ODD1:.*]]  = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
+// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
+// Even-Indexed FMAs
+// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
+// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
+// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
+// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
+// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
+// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[EVEN1]], %[[BCST3]], %[[FMA_ODD0]]
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -104,7 +111,7 @@ module attributes {transform.with_named_sequence} {
 !memrefA = memref<1x1x1xbf16>
 !memrefB = memref<1x8x2xbf16>
 
-func.func @negative_fma_lhs_multiple_consumer(
+func.func @negative_fma_operand_has_multiple_consumer(
   %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB,
   %arg3: !memrefA, %arg4: !memrefB, %arg5: !vec) -> !vec
 {
@@ -123,66 +130,15 @@ func.func @negative_fma_lhs_multiple_consumer(
   return %12 : !vec
 }
 
-// CHECK-LABEL: @negative_fma_lhs_multiple_consumer
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
-// CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
-// CHECK: vector.fma
-// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
-// CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
-// CHECK: vector.fma
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %0 {
-      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
-    } : !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-!vec = vector<8xf32>
-!memrefA = memref<1x1x1xbf16>
-!memrefB = memref<1x8x2xbf16>
-
-func.func @negative_fma_rhs_multiple_consumer(
-  %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
-  %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
-{
-  %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
-  %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
-  %2 = vector.fma %0, %1, %arg6 : !vec
-  %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
-  %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
-  %5 = vector.fma %3, %4, %2 : !vec
-  %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
-  %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
-  %8 = vector.fma %6, %7, %arg6 : !vec
-  %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
-  %10 = vector.fma %9, %4, %8 : !vec
-  %11 = vector.fma %5, %10, %arg6 : !vec
-  return %11 : !vec
-}
-
-// CHECK-LABEL: @negative_fma_rhs_multiple_consumer
-// CHECK: x86vector.avx.bcst_to_f32.packed
+// The vector.fma at %5 uses %3 as its LHS operand, which has two consumers; therefore,
+// the rewrite is not applied.
+// CHECK-LABEL: @negative_fma_operand_has_multiple_consumer
 // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
 // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
 // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: vector.fma
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -200,7 +156,7 @@ module attributes {transform.with_named_sequence} {
 !memrefA = memref<1x1x1xbf16>
 !memrefB = memref<1x8x2xbf16>
 
-func.func @negative_fma_multiple_consumer(
+func.func @negative_fma_has_multiple_consumer(
   %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
   %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
 {
@@ -220,19 +176,13 @@ func.func @negative_fma_multiple_consumer(
   return %12 : !vec
 }
 
-// CHECK-LABEL: @negative_fma_multiple_consumer
-// CHECK: x86vector.avx.bcst_to_f32.packed
+// vector.fma at %5 has two uses; therefore, the rewrite is not applied.
+// CHECK-LABEL: @negative_fma_has_multiple_consumer
 // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
 // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
 // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
-// CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
-// CHECK: vector.fma
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -283,20 +233,16 @@ func.func @negative_no_shuffle_outside_block(
   return %loop : !vec
 }
 
+// vector.fma at %5 has its consumer in another block (%12); therefore, the rewrite is
+// not applied.
 // CHECK-LABEL: @negative_no_shuffle_outside_block
-// CHECK: x86vector.avx.bcst_to_f32.packed
 // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
 // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
 // CHECK: vector.fma
 // CHECK: scf.if
-// CHECK: x86vector.avx.bcst_to_f32.packed
 // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
 // CHECK: vector.fma
-// CHECK: x86vector.avx.bcst_to_f32.packed
-// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
-// CHECK: vector.fma
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {

>From 03eb2259d21491684f926cbd2cd1da2e6f010d00 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 15 Jan 2026 18:53:15 -0800
Subject: [PATCH 4/6] added test-cases for shape-cast

---
 .../Transforms/ShuffleVectorFMAOps.cpp         | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
index 90109d03d785e..ab192ec0f7bd0 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
@@ -23,7 +23,7 @@ namespace {
 
 // Validates whether the given operation is an x86vector operation and has only
 // one consumer.
-static bool validateX86OpsHasOneUser(Value op) {
+static bool validateFMAOperands(Value op) {
   if (auto cvt = op.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>())
     return cvt.getResult().hasOneUse();
 
@@ -46,7 +46,7 @@ static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
       !isa<x86vector::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
     return false;
 
-  if (!validateX86OpsHasOneUser(lhs) || !validateX86OpsHasOneUser(rhs))
+  if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs))
     return false;
 
   if (lhs.getDefiningOp()->getBlock() != rhs.getDefiningOp()->getBlock())
@@ -66,7 +66,7 @@ static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
 }
 
 // Moves vector.fma along with the lhs and rhs defining operation before its
-// comsumer. If the consumer is vector.ShapeCastOp and has only one user then
+// consumer. If the consumer is vector.ShapeCastOp and has only one user then
 // move before the consumer of vector.ShapeCastOp.
 // TODO: Move before first consumer, if there are multiple.
 static void moveFMA(vector::FMAOp fmaOp) {
@@ -116,15 +116,15 @@ static void moveFMA(vector::FMAOp fmaOp) {
 //   %1 = x86vector.avx.bcst_to_f32.packed
 //   %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
 //   %3 = vector.fma %1, %2, %arg1
-//   %4 = x86vector.avx.bcst_to_f32.packed
-//   %5 = x86vector.avx.cvt.packed.odd.indexed_to_f32
-//   %6 = vector.fma %4, %5, %arg2
 //   %7 = x86vector.avx.bcst_to_f32.packed
-//   %8 = x86vector.avx.cvt.packed.even.indexed_to_f32
-//   %9 = vector.fma %7, %8, %3
+//   %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+//   %9 = vector.fma %7, %8, %arg2
+//   %4 = x86vector.avx.bcst_to_f32.packed
+//   %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
+//   %6 = vector.fma %4, %5, %3
 //   %10 = x86vector.avx.bcst_to_f32.packed
 //   %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
-//   %12 = vector.fma %10, %11, %6
+//   %12 = vector.fma %10, %11, %9
 //   yield %9, %12
 // ```
 // TODO: Shuffling supported only if the FMA, lhs/rhs defining operations

>From 0d91162538198eb9dc537ed176f18b22658fbadd Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 15 Jan 2026 18:57:35 -0800
Subject: [PATCH 5/6] added test-cases for shape-cast

---
 .../X86Vector/shuffle-vector-fmas.mlir        | 57 +++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
index 52bd148353b4c..4bf930b51c0c2 100644
--- a/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
+++ b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
@@ -107,6 +107,63 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vec = vector<8xf32>
+!vecOut = vector<1x8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+func.func @shuffle_fma_with_shape_cast(
+  %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+  %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vecOut
+{
+  %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+  %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %2 = vector.fma %0, %1, %arg6 : !vec
+  %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+  %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %5 = vector.fma %3, %4, %2 : !vec
+  %res1 = vector.shape_cast %5 : !vec to !vecOut
+  %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+  %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+  %8 = vector.fma %6, %7, %arg6 : !vec
+  %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+  %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+  %11 = vector.fma %9, %10, %8 : !vec
+  %res2 = vector.shape_cast %11 : !vec to !vecOut
+  %12 = arith.addf %res1, %res2 : !vecOut
+  return %12 : !vecOut
+}
+
+// CHECK-LABEL: @shuffle_fma_with_shape_cast
+// Odd-Indexed FMAs
+// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
+// CHECK: %[[ODD0:.*]]  = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
+// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
+// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
+// CHECK: %[[ODD1:.*]]  = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
+// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
+// Even-Indexed FMAs
+// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
+// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
+// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD1]]
+// CHECK: vector.shape_cast
+// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
+// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
+// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD0]]
+// CHECK: vector.shape_cast
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vec = vector<8xf32>
 !memrefA = memref<1x1x1xbf16>
 !memrefB = memref<1x8x2xbf16>

>From c60d3d86a45af0889b6134b665ae8f9e716a88c3 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 20 Jan 2026 01:59:40 -0800
Subject: [PATCH 6/6] move the FMA operations using rewriter

---
 .../Transforms/ShuffleVectorFMAOps.cpp        | 24 +++++++++++--------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
index ab192ec0f7bd0..a66546a5d1e45 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
@@ -69,7 +69,7 @@ static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
 // consumer. If the consumer is vector.ShapeCastOp and has only one user then
 // move before the consumer of vector.ShapeCastOp.
 // TODO: Move before first consumer, if there are multiple.
-static void moveFMA(vector::FMAOp fmaOp) {
+static void moveFMA(PatternRewriter &rewriter, vector::FMAOp fmaOp) {
   Operation *consumer = *fmaOp.getResult().getUsers().begin();
 
   if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(consumer)) {
@@ -77,18 +77,19 @@ static void moveFMA(vector::FMAOp fmaOp) {
       Operation *nxtConsumer = *shapeCastOp.getResult().getUsers().begin();
       if (nxtConsumer->getBlock() == fmaOp->getBlock()) {
         consumer = *shapeCastOp.getResult().getUsers().begin();
-        fmaOp.getLhs().getDefiningOp()->moveBefore(consumer);
-        fmaOp.getRhs().getDefiningOp()->moveBefore(consumer);
-        fmaOp->moveBefore(consumer);
-        shapeCastOp->moveBefore(consumer);
+        rewriter.moveOpBefore(fmaOp.getLhs().getDefiningOp(), consumer);
+        rewriter.moveOpBefore(fmaOp.getRhs().getDefiningOp(), consumer);
+        rewriter.moveOpBefore(fmaOp.getOperation(), consumer);
+        rewriter.moveOpBefore(shapeCastOp.getOperation(), consumer);
         return;
       }
     }
   }
 
-  fmaOp.getLhs().getDefiningOp()->moveBefore(consumer);
-  fmaOp.getRhs().getDefiningOp()->moveBefore(consumer);
-  fmaOp->moveBefore(consumer);
+  rewriter.moveOpBefore(fmaOp.getLhs().getDefiningOp(), consumer);
+  rewriter.moveOpBefore(fmaOp.getRhs().getDefiningOp(), consumer);
+  rewriter.moveOpBefore(fmaOp.getOperation(), consumer);
+
   return;
 }
 
@@ -164,11 +165,14 @@ struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
     }
 
     if (fmaOps.empty())
-      return failure();
+      return rewriter.notifyMatchFailure(
+          fmaOp, "No eligible FMA operations were found: the operation may "
+                 "already be shuffled, there may be no following FMAs, or the "
+                 "following FMAs do not satisfy the shuffle conditions.");
 
     fmaOps.push_back(fmaOp);
     for (auto fmaOp : fmaOps)
-      moveFMA(fmaOp);
+      moveFMA(rewriter, fmaOp);
 
     return success();
   }



More information about the Mlir-commits mailing list