[Mlir-commits] [mlir] [Draft][mlir][x86vector] Shuffle FMAs (PR #172823)

Arun Thangamani llvmlistbot at llvm.org
Thu Dec 18 01:06:15 PST 2025


https://github.com/arun-thmn created https://github.com/llvm/llvm-project/pull/172823

None

>From 027028008c5811f95d4389986e56cd27f70c12ab Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 18 Dec 2025 01:05:00 -0800
Subject: [PATCH] initial commit for shuffling fma

---
 .../TransformOps/X86VectorTransformOps.td     |  11 ++
 .../mlir/Dialect/X86Vector/Transforms.h       |   2 +
 .../TransformOps/X86VectorTransformOps.cpp    |   5 +
 .../X86Vector/Transforms/CMakeLists.txt       |   1 +
 .../Transforms/ShuffleVectorFMAOps.cpp        | 127 ++++++++++++++++++
 .../X86Vector/shuffle-vector-fmas.mlir        |  51 +++++++
 6 files changed, 197 insertions(+)
 create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
 create mode 100644 mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 12ba5e9f11141..a1c946dc3f91d 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -49,6 +49,17 @@ def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.shuffle_vector_fma_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns to shuffle `vector.fma` operations fed by
+    `x86vector.avx.cvt.packed.even.indexed_to_f32`: matching FMAs are sunk,
+    together with the operations producing their multiplicands, to just
+    before the user of the FMA result.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 
 #endif // X86VECTOR_TRANSFORM_OPS
 
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index b9c9054f57890..93c4544a61f62 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -95,6 +95,8 @@ void populateVectorContractToPackedTypeDotProductPatterns(
 // range by placing them at their earliest legal use site
 void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
 
+// Adds patterns that shuffle vector.fma ops having a multiplicand produced by
+// x86vector::CvtPackedEvenIndexedToF32Op: such FMAs are sunk, along with the
+// ops producing their multiplicands, to just before the user of their result.
+void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 25772f2aa57f4..a23f9b701199e 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -37,6 +37,11 @@ void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
   x86vector::populateSinkVectorProducerOpsPatterns(patterns);
 }
 
+// Fills `patterns` for the shuffle_vector_fma_ops transform op by forwarding
+// to x86vector::populateShuffleVectorFMAOpsPatterns.
+void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  x86vector::populateShuffleVectorFMAOpsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index cc4d3cac0f7ea..95471bc72f65d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   VectorContractToFMA.cpp
   VectorContractToPackedTypeDotProduct.cpp
   SinkVectorProducerOps.cpp
+  ShuffleVectorFMAOps.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
new file mode 100644
index 0000000000000..c3160529f1fd7
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
@@ -0,0 +1,127 @@
+//===- ShuffleVectorFMAOps.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+namespace {
+
+/// Returns true iff `op` is the single-use result of an x86vector
+/// CvtPackedEvenIndexedToF32Op or BcstToPackedF32Op. Values with any other
+/// producer, or with no defining op at all, are rejected.
+static bool validateX86OpsHasOneUser(Value op) {
+  Operation *producer = op.getDefiningOp();
+  if (!producer)
+    return false;
+
+  // Only these two x86vector producers are accepted.
+  if (!isa<x86vector::CvtPackedEvenIndexedToF32Op,
+           x86vector::BcstToPackedF32Op>(producer))
+    return false;
+
+  return op.hasOneUse();
+}
+
+/// Returns true when `fmaOp` is a candidate for shuffling: at least one of
+/// its multiplicands is produced by CvtPackedEvenIndexedToF32Op, both
+/// multiplicands pass validateX86OpsHasOneUser, and its result has one use.
+static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
+  Value lhs = fmaOp.getLhs();
+  Value rhs = fmaOp.getRhs();
+
+  // Value::getDefiningOp<T>() yields null for block arguments and for other
+  // producers, whereas isa<T>(value.getDefiningOp()) asserts on a null op.
+  if (!lhs.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>() &&
+      !rhs.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>())
+    return false;
+
+  if (!validateX86OpsHasOneUser(lhs) || !validateX86OpsHasOneUser(rhs))
+    return false;
+
+  return fmaOp.getResult().hasOneUse();
+}
+
+/// Sinks `fmaOp` and the producers of its two multiplicands to just before the
+/// op that uses the FMA result. If that user is a vector.shape_cast with a
+/// single use, the cast is sunk as well and everything is placed just before
+/// the cast's user instead. `fmaOp` must have passed validateVectorFMAOp,
+/// which guarantees both producers exist and the FMA result has one use.
+static void moveFMA(vector::FMAOp fmaOp, PatternRewriter &rewriter) {
+  Operation *lhsProducer = fmaOp.getLhs().getDefiningOp();
+  Operation *rhsProducer = fmaOp.getRhs().getDefiningOp();
+  Operation *fmaUser = *fmaOp.getResult().getUsers().begin();
+
+  // Look through a single-use shape_cast so the cast travels with its FMA.
+  auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(fmaUser);
+  const bool sinkCast = shapeCastOp && shapeCastOp.getResult().hasOneUse();
+  Operation *insertionPoint =
+      sinkCast ? *shapeCastOp.getResult().getUsers().begin() : fmaUser;
+
+  // All moves go through the rewriter so the pattern driver is notified.
+  rewriter.moveOpBefore(lhsProducer, insertionPoint);
+  rewriter.moveOpBefore(rhsProducer, insertionPoint);
+  rewriter.moveOpBefore(fmaOp.getOperation(), insertionPoint);
+  if (sinkCast)
+    rewriter.moveOpBefore(shapeCastOp.getOperation(), insertionPoint);
+}
+
+/// Shuffles vector.fma ops fed by CvtPackedEvenIndexedToF32Op. Starting from
+/// a valid `fmaOp`, the ops following it in the block are scanned for further
+/// valid FMAs. The scan stops as soon as the first FMA encountered already
+/// has an even-indexed multiplicand. Every collected FMA, and finally `fmaOp`
+/// itself, is then sunk next to its user via moveFMA.
+struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
+  using OpRewritePattern<vector::FMAOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::FMAOp fmaOp,
+                                PatternRewriter &rewriter) const override {
+    if (!validateVectorFMAOp(fmaOp))
+      return failure();
+
+    using EvenCvtOp = x86vector::CvtPackedEvenIndexedToF32Op;
+    llvm::SmallVector<vector::FMAOp> fmaOps;
+    bool sawOtherFMA = false;
+
+    for (Operation *nextOp = fmaOp->getNextNode(); nextOp;
+         nextOp = nextOp->getNextNode()) {
+      auto fma = dyn_cast<vector::FMAOp>(nextOp);
+      if (!fma)
+        continue;
+
+      // getDefiningOp<T>() is null-safe when a multiplicand is a block
+      // argument, unlike isa<T>(value.getDefiningOp()).
+      const bool hasEvenOperand = fma.getLhs().getDefiningOp<EvenCvtOp>() ||
+                                  fma.getRhs().getDefiningOp<EvenCvtOp>();
+
+      // An even-indexed FMA reached with no other FMA in between ends the
+      // scan; nothing collected so far means there is nothing to shuffle.
+      if (hasEvenOperand && !sawOtherFMA)
+        break;
+
+      if (validateVectorFMAOp(fma))
+        fmaOps.push_back(fma);
+
+      sawOtherFMA = true;
+    }
+
+    if (fmaOps.empty())
+      return failure();
+
+    // The root FMA is moved last, after every FMA collected by the scan.
+    fmaOps.push_back(fmaOp);
+    for (vector::FMAOp fma : fmaOps)
+      moveFMA(fma, rewriter);
+
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the ShuffleVectorFMAOps rewrite pattern in `patterns`.
+void x86vector::populateShuffleVectorFMAOpsPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<ShuffleVectorFMAOps>(ctx);
+}
diff --git a/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
new file mode 100644
index 0000000000000..93abecdf07914
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+// Test input: two FMA chains, each an odd-indexed FMA feeding an even-indexed
+// FMA through the accumulator, written one after the other. The first chain
+// goes through a pair of shape_casts before both chains meet in a final
+// vector.fma.
+func.func @shuffle_vector_fma(
+  %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+  %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
+{
+  // First chain: odd-indexed FMA, then even-indexed FMA accumulating on it.
+  %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+  %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %2 = vector.fma %0, %1, %arg6 : !vec
+  %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+  %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %5 = vector.fma %3, %4, %2 : !vec
+  %55 = vector.shape_cast %5 : vector<8xf32> to vector<1x8xf32>
+  // Second chain: same structure on %arg3/%arg4/%arg5.
+  %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+  %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
+  %8 = vector.fma %6, %7, %arg6 : !vec
+  %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
+  %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
+  %11 = vector.fma %9, %10, %8 : !vec
+  // Both chain results are combined here.
+  %13 = vector.shape_cast %55 : vector<1x8xf32> to vector<8xf32>
+  %12 = vector.fma %13, %11, %arg6 : !vec
+  return %12 : !vec
+}
+
+// CHECK-LABEL: @shuffle_vector_fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+// CHECK: x86vector.avx.bcst_to_f32.packed
+// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
+// CHECK: vector.fma
+
+// Transform script: matches the func.func ops in the payload and applies the
+// x86vector shuffle_vector_fma_ops pattern set to them.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list