[Mlir-commits] [mlir] [wip][mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (PR #174590)

Arun Thangamani llvmlistbot at llvm.org
Tue Jan 6 06:37:43 PST 2026


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/174590

>From 5ddf7e96ddeb975534668b0bd0c991a4c483a171 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 6 Jan 2026 05:44:34 -0800
Subject: [PATCH 1/2] initial commit for shuffle VC output

---
 .../TransformOps/X86VectorTransformOps.td     |  11 +
 .../mlir/Dialect/X86Vector/Transforms.h       |   3 +
 .../Dialect/X86Vector/Utils/X86VectorUtils.h  |  26 ++
 .../TransformOps/X86VectorTransformOps.cpp    |   5 +
 .../X86Vector/Transforms/CMakeLists.txt       |   1 +
 .../ShuffleBF16VectorContractResult.cpp       | 150 ++++++++
 .../X86Vector/Utils/X86VectorUtils.cpp        | 325 ++++++++++++++++++
 .../shuffle-bf16-vector-contract-result.mlir  |  66 ++++
 8 files changed, 587 insertions(+)
 create mode 100644 mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
 create mode 100644 mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir

diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 3c73eadf82167..00c611a9f3a7a 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -60,6 +60,17 @@ def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyShuffleBF16VectorContractResultPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.shuffle_bf16_vector_contract_result",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns that pair up flat-layout (non-VNNI) BF16
+    `vector.contract` operations with F32 accumulators and interleave their
+    accumulators and results: `vector.shuffle` ops are inserted right after
+    the accumulator reads and right before the result writes. The
+    accumulator must have exactly one non-unit dimension, of size 8 or 16.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 
 #endif // X86VECTOR_TRANSFORM_OPS
 
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index c25cdaf2d9428..538ffaac79998 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -100,6 +100,9 @@ void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
 // range by placing them at their earliest legal use site.
 void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
 
+// Collect patterns that interleave (with vector.shuffle) the accumulators and
+// results of pairs of flat-layout BF16 vector.contract ops that accumulate
+// into F32.
+void populateShuffleBF16VectorContractResultPatterns(
+    RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
index 2de9a3122cbd9..7267c5c2032d9 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -9,6 +9,9 @@
 #ifndef MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
 #define MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
 
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
 #include <cstdint>
@@ -26,6 +29,29 @@ namespace x86vector {
 bool isInVnniLayout(Operation *op, llvm::ArrayRef<AffineMap> indexingMaps,
                     std::optional<unsigned> blockingFactor = std::nullopt);
 
+/// Returns true if `pairContOp` can be paired with `contractOp`: both share
+/// the LHS (when `rhsHasMultipleNonUnitDims`) or the RHS (otherwise), and
+/// their other operands are loaded (vector.transfer_read / vector.load) from
+/// the same buffer, with the pair's indices offset by `nonUnitDimValue`.
+bool validatePairVectorContract(vector::ContractionOp contractOp,
+                                vector::ContractionOp pairContOp,
+                                bool rhsHasMultipleNonUnitDims,
+                                int64_t nonUnitDimValue);
+
+/// Walks up the def chain of `v` (through single-operand producers and
+/// scf.for iter_args) and returns the vector.transfer_read / vector.load it
+/// originates from, or nullptr. Tracing stops at vector.shape_cast and
+/// vector.shuffle.
+Operation *traceToVectorReadLikeParentOperation(mlir::Value v);
+
+/// Searches the transitive users of `v` (following scf.yield and scf.for
+/// results) and returns the first vector.transfer_write / vector.store found,
+/// or nullptr. A vector.shape_cast / vector.shuffle user aborts the search.
+Operation *traceToVectorWriteLikeUserOperation(mlir::Value v);
+
+/// Interleaves, with vector.shuffle, the vectors stored by the write-like ops
+/// `op` and `op1` and makes each op store its shuffled half (cast back to
+/// `accTy`). The new ops are inserted before the earlier of the two ops,
+/// which must be in the same block.
+void shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *op,
+                              Operation *op1, int64_t nonUnitDimAcc,
+                              VectorType accTy);
+
+/// Interleaves, with vector.shuffle, the results of the read-like ops `op`
+/// and `op1` right after the later of the two, and redirects vector.contract
+/// and scf.for users of the original results to the shuffled values.
+void shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *op,
+                            Operation *op1, vector::ContractionOp contractOp,
+                            vector::ContractionOp pairContractOp,
+                            int64_t nonUnitDimAcc, VectorType accTy);
+
+/// Like shuffleAfterReadLikeOp, but uses hard-coded 32-entry shuffle masks.
+/// NOTE(review): the mask size suggests `nonUnitDimAcc` must be 32 - confirm.
+void shuffleNonUnitDimOperand(PatternRewriter &rewriter, Operation *op,
+                              Operation *op1, vector::ContractionOp contractOp,
+                              vector::ContractionOp pairContractOp,
+                              int64_t nonUnitDimAcc, VectorType Ty);
+
 } // namespace x86vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index e77d30c9c5ffb..e40ddd3a4b1c0 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -42,6 +42,11 @@ void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
   x86vector::populateSinkVectorProducerOpsPatterns(patterns);
 }
 
+void mlir::transform::ApplyShuffleBF16VectorContractResultPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  // Forward to the pattern-population entry point of the x86vector dialect.
+  ::mlir::x86vector::populateShuffleBF16VectorContractResultPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index bbd9be880eb0a..acbc7fcfb635e 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   VectorContractToPackedTypeDotProduct.cpp
   VectorContractBF16ToFMA.cpp
   SinkVectorProducerOps.cpp
+  ShuffleBF16VectorContractResult.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
new file mode 100644
index 0000000000000..2081de9f2be91
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
@@ -0,0 +1,150 @@
+//===-
+//ShuffleBF16VectorContractResult.cpp-----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+namespace {
+
+/// Rewrites a pair of flat-layout (non-VNNI) BF16 `vector.contract` ops with
+/// F32 accumulators so that their accumulators are interleaved with
+/// `vector.shuffle` right after they are read, and their results are
+/// interleaved back right before they are written.
+///
+/// The accumulator must have exactly one non-unit dimension, of size 8 or 16.
+/// The pair is the first later `vector.contract` in the same block that has
+/// the same accumulator type and is accepted by `validatePairVectorContract`.
+struct ShuffleBF16VectorContractResult
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind.");
+
+    // TODO: Move this validation to a common utility folder. Planned to
+    // do once (code refactoring), all architecture specific nanokernel
+    // passes are merged into the repo.
+    VectorType lhsTy = contractOp.getLhsType();
+    if (!lhsTy.getElementType().isBF16())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only BF16 lowering is supported.");
+
+    if (isInVnniLayout(contractOp.getOperation(),
+                       contractOp.getIndexingMapsArray(),
+                       /*blockingFactor=*/2))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Input matrices in VNNI format.");
+
+    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Wrong accumulator type.");
+
+    if (!accTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only F32 accumulation supported for BF16 type.");
+
+    auto isNonUnit = [](int64_t dim) { return dim != 1; };
+
+    ArrayRef<int64_t> accShape = accTy.getShape();
+    if (llvm::count_if(accShape, isNonUnit) != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "A or B should be a non-unit dim in acc.");
+
+    int64_t nonUnitDimValue = *llvm::find_if(accShape, isNonUnit);
+    if (nonUnitDimValue != 8 && nonUnitDimValue != 16)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The accumulator dimension should be 8 or 16");
+
+    // Signed counts: comparing with `> max` is exact, with no unsigned
+    // wrap-around when an operand has no non-unit dims at all.
+    int64_t numNonUnitLhs = llvm::count_if(lhsTy.getShape(), isNonUnit);
+    int64_t numNonUnitRhs =
+        llvm::count_if(contractOp.getRhsType().getShape(), isNonUnit);
+
+    // An operand "carries the non-unit dims" when it has more than
+    // `maxNonUnitDims` of them (1 for a 16-wide accumulator, 0 for an 8-wide
+    // one). At most one of LHS/RHS may do so.
+    int64_t maxNonUnitDims = nonUnitDimValue == 16 ? 1 : 0;
+    if (numNonUnitLhs > maxNonUnitDims && numNonUnitRhs > maxNonUnitDims)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects unit dimensions for either "
+                                         "LHS or RHS shape.");
+
+    // Selects which operand the pair must share: the LHS when the RHS
+    // carries the non-unit dims, the RHS otherwise.
+    bool rhsHasMultipleNonUnitDims = numNonUnitRhs > maxNonUnitDims;
+
+    vector::ContractionOp pairContractOp;
+    for (Operation *nextOp = contractOp->getNextNode(); nextOp;
+         nextOp = nextOp->getNextNode()) {
+      auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
+      // The shuffles reinterpret both accumulators as `accTy`.
+      if (contOp && contOp.getAccType() == accTy &&
+          validatePairVectorContract(contractOp, contOp,
+                                     rhsHasMultipleNonUnitDims,
+                                     nonUnitDimValue)) {
+        pairContractOp = contOp;
+        break;
+      }
+    }
+
+    if (!pairContractOp)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "No pairable vector.contract found.");
+
+    Operation *accReadOp0 =
+        traceToVectorReadLikeParentOperation(contractOp.getAcc());
+    Operation *accReadOp1 =
+        traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+    if (!accReadOp0 || !accReadOp1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Accumulators are not produced by read-like ops.");
+
+    Operation *resultWriteOp0 =
+        traceToVectorWriteLikeUserOperation(contractOp.getResult());
+    Operation *resultWriteOp1 =
+        traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
+    if (!resultWriteOp0 || !resultWriteOp1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Results are not consumed by write-like ops.");
+
+    // The shuffle helpers order each op pair with isBeforeInBlock, which
+    // requires two ops of the same block; shuffling a value with itself
+    // would be meaningless.
+    if (accReadOp0 == accReadOp1 || resultWriteOp0 == resultWriteOp1 ||
+        accReadOp0->getBlock() != accReadOp1->getBlock() ||
+        resultWriteOp0->getBlock() != resultWriteOp1->getBlock())
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects distinct read/write ops in a common block.");
+
+    // The helpers shape_cast the read results / written vectors to and from
+    // `accTy`; operand 0 is the stored vector of transfer_write and store.
+    if (accReadOp0->getResult(0).getType() != accTy ||
+        accReadOp1->getResult(0).getType() != accTy ||
+        resultWriteOp0->getOperand(0).getType() != accTy ||
+        resultWriteOp1->getOperand(0).getType() != accTy)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Read/write vector types differ from the accumulator.");
+
+    // New accumulators are materialized after the later read, and the
+    // results are shuffled before the earlier write; make sure the rewritten
+    // IR stays dominance-correct.
+    DominanceInfo domInfo;
+    Operation *laterRead =
+        accReadOp0->isBeforeInBlock(accReadOp1) ? accReadOp1 : accReadOp0;
+    Operation *earlierWrite = resultWriteOp0->isBeforeInBlock(resultWriteOp1)
+                                  ? resultWriteOp0
+                                  : resultWriteOp1;
+    auto allUsersAfterLaterRead = [&](Operation *readOp) {
+      return llvm::all_of(readOp->getUsers(), [&](Operation *user) {
+        return domInfo.properlyDominates(laterRead, user);
+      });
+    };
+    if (!allUsersAfterLaterRead(accReadOp0) ||
+        !allUsersAfterLaterRead(accReadOp1))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Accumulator used before both reads are available.");
+    if (!domInfo.properlyDominates(resultWriteOp0->getOperand(0),
+                                   earlierWrite) ||
+        !domInfo.properlyDominates(resultWriteOp1->getOperand(0),
+                                   earlierWrite))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Results not available before the first write.");
+
+    shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
+                           pairContractOp, nonUnitDimValue, accTy);
+    shuffleBeforeWriteLikeOp(rewriter, resultWriteOp0, resultWriteOp1,
+                             nonUnitDimValue, accTy);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void x86vector::populateShuffleBF16VectorContractResultPatterns(
+    RewritePatternSet &patterns) {
+  // Registers the single pattern that pairs and shuffles BF16 contractions.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ShuffleBF16VectorContractResult>(context);
+}
diff --git a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
index ccb2e92fdd9e2..65cd0773160ef 100644
--- a/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
+++ b/mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp
@@ -10,11 +10,18 @@
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include <cassert>
+
 namespace mlir {
 namespace x86vector {
 
@@ -104,5 +111,323 @@ bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
   return true;
 }
 
+/// The pair of `vector.shuffle` masks producing the low and high halves of
+/// the interleaving of two equally sized vectors.
+struct ShuffleMasks {
+  ArrayRef<int64_t> maskLo; // Mask of the first (low) shuffled result.
+  ArrayRef<int64_t> maskHi; // Mask of the second (high) shuffled result.
+};
+
+/// Returns the `vector.shuffle` masks that interleave two flattened vectors
+/// A and B of `nonUnitDimAcc` elements each (mask entries >= `nonUnitDimAcc`
+/// select from B):
+///   - 8:  lo = A0 B0 A1 B1 A2 B2 A3 B3,     hi = A4 B4 ... A7 B7
+///   - 16: lo = A0-3 B0-3 A4-7 B4-7,         hi = A8-11 B8-11 A12-15 B12-15
+/// The masks live in function-local static storage, so the returned
+/// ArrayRefs never dangle. File-local helper, hence `static`.
+static ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
+  // We only support these two layouts for now.
+  assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
+         "Unsupported nonUnitDimAcc value");
+
+  static constexpr int64_t maskLo8[] = {0, 8, 1, 9, 2, 10, 3, 11};
+  static constexpr int64_t maskHi8[] = {4, 12, 5, 13, 6, 14, 7, 15};
+
+  static constexpr int64_t maskLo16[] = {0, 1, 2, 3, 16, 17, 18, 19,
+                                         4, 5, 6, 7, 20, 21, 22, 23};
+  static constexpr int64_t maskHi16[] = {8,  9,  10, 11, 24, 25, 26, 27,
+                                         12, 13, 14, 15, 28, 29, 30, 31};
+
+  if (nonUnitDimAcc == 16)
+    return {maskLo16, maskHi16};
+
+  // nonUnitDimAcc == 8
+  return {maskLo8, maskHi8};
+}
+
+/// Walks up the def chain of `v` and returns the vector.transfer_read /
+/// vector.load that produced it, or nullptr.
+///
+/// Tracing continues through single-operand ops whose operand has the same
+/// type as the traced value, and from scf.for iter_args to their init
+/// values. It stops (returning nullptr) at vector.shape_cast /
+/// vector.shuffle, at type-changing ops (callers reinterpret the read result
+/// as the accumulator type), and at any other producer.
+Operation *traceToVectorReadLikeParentOperation(Value v) {
+  while (true) {
+    if (Operation *defOp = v.getDefiningOp()) {
+      if (isa<vector::TransferReadOp, vector::LoadOp>(defOp))
+        return defOp;
+
+      // Already reshaped/shuffled values (e.g. by an earlier application of
+      // the shuffle rewrite) are not traced further.
+      if (isa<vector::ShapeCastOp, vector::ShuffleOp>(defOp))
+        return nullptr;
+
+      // Accumulators usually forward the value unchanged; only follow ops
+      // that preserve the type.
+      if (defOp->getNumOperands() == 1 &&
+          defOp->getOperand(0).getType() == v.getType()) {
+        v = defOp->getOperand(0);
+        continue;
+      }
+
+      return nullptr;
+    }
+
+    // Not an op result, so `v` is a block argument.
+    auto barg = cast<BlockArgument>(v);
+    // The owner block may be detached, in which case there is no parent op.
+    auto forOp =
+        dyn_cast_or_null<scf::ForOp>(barg.getOwner()->getParentOp());
+
+    // arg0 of an scf.for body is the induction variable, not an iter_arg.
+    if (!forOp || barg.getArgNumber() == 0)
+      return nullptr;
+
+    v = forOp.getInitArgs()[barg.getArgNumber() - 1];
+  }
+}
+
+/// Searches the transitive users of `v` and returns the first
+/// vector.transfer_write / vector.store found, or nullptr.
+///
+/// scf.yield operands are followed to the matching parent results and
+/// scf.for init operands to the tied loop results; any other user is
+/// followed through all of its results. A vector.shape_cast /
+/// vector.shuffle user aborts the search.
+Operation *traceToVectorWriteLikeUserOperation(Value v) {
+  for (OpOperand &use : v.getUses()) {
+    Operation *user = use.getOwner();
+
+    // --- TERMINAL OPS ---
+    if (isa<vector::TransferWriteOp, vector::StoreOp>(user))
+      return user;
+
+    if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
+      return nullptr;
+
+    unsigned idx = use.getOperandNumber();
+
+    // --- SCF YIELD ---
+    // Yield operand `idx` becomes parent result `idx` (scf.for, scf.if, ...).
+    // scf.while's yield feeds the "before" region instead, and the index
+    // must be in range.
+    if (isa<scf::YieldOp>(user)) {
+      Operation *parent = user->getParentOp();
+      if (parent && !isa<scf::WhileOp>(parent) &&
+          idx < parent->getNumResults())
+        if (Operation *res =
+                traceToVectorWriteLikeUserOperation(parent->getResult(idx)))
+          return res;
+      continue;
+    }
+
+    // --- SCF FOR ---
+    // Operands are (lb, ub, step, init_args...); init arg `i` is tied to
+    // result `i`, so skip the control operands.
+    if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+      unsigned numCtrl = forOp.getNumControlOperands();
+      if (idx >= numCtrl)
+        if (Operation *res = traceToVectorWriteLikeUserOperation(
+                forOp.getResult(idx - numCtrl)))
+          return res;
+      continue;
+    }
+
+    // --- GENERIC CASE ---
+    for (Value res : user->getResults())
+      if (Operation *found = traceToVectorWriteLikeUserOperation(res))
+        return found;
+  }
+
+  return nullptr;
+}
+
+/// Redirects every vector.contract and scf.for use of `oldVal` to `newVal`.
+///
+/// Each update goes through `rewriter.modifyOpInPlace` so the pattern driver
+/// is notified of the in-place change. `targetContract` is currently unused:
+/// all vector.contract users are updated, not only that one.
+static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
+                        mlir::Operation *targetContract,
+                        mlir::PatternRewriter &rewriter) {
+  (void)targetContract;
+  // Early-inc: setting a use removes it from oldVal's use list.
+  for (mlir::OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) {
+    mlir::Operation *user = use.getOwner();
+    if (!mlir::isa<mlir::vector::ContractionOp, mlir::scf::ForOp>(user))
+      continue;
+    rewriter.modifyOpInPlace(user, [&] { use.set(newVal); });
+  }
+}
+
+/// Interleaves the results of the read-like ops `opA` and `opB`.
+///
+/// Right after whichever of the two comes last, both results are flattened
+/// to 1-D vectors of `nonUnitDimAcc` elements, shuffled into a low and a high
+/// half (see getShuffleMasks), and cast back to `accTy`. vector.contract and
+/// scf.for users of opA's / opB's result then use the low / high half.
+void shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
+                            mlir::Operation *opA, mlir::Operation *opB,
+                            mlir::vector::ContractionOp contractA,
+                            mlir::vector::ContractionOp contractB,
+                            int64_t nonUnitDimAcc, mlir::VectorType accTy) {
+  // Both read results must be available at the insertion point.
+  Operation *lastRead = opB;
+  if (opB->isBeforeInBlock(opA))
+    lastRead = opA;
+
+  rewriter.setInsertionPointAfter(lastRead);
+  Location loc = lastRead->getLoc();
+
+  VectorType flatTy = VectorType::get(nonUnitDimAcc, accTy.getElementType());
+  ShuffleMasks masks = getShuffleMasks(nonUnitDimAcc);
+
+  Value flatA =
+      vector::ShapeCastOp::create(rewriter, loc, flatTy, opA->getResult(0))
+          .getResult();
+  Value flatB =
+      vector::ShapeCastOp::create(rewriter, loc, flatTy, opB->getResult(0))
+          .getResult();
+
+  Value lowHalf = vector::ShuffleOp::create(rewriter, loc, flatTy, flatA,
+                                            flatB, masks.maskLo)
+                      .getResult();
+  Value highHalf = vector::ShuffleOp::create(rewriter, loc, flatTy, flatA,
+                                             flatB, masks.maskHi)
+                       .getResult();
+
+  Value accA =
+      vector::ShapeCastOp::create(rewriter, loc, accTy, lowHalf).getResult();
+  Value accB =
+      vector::ShapeCastOp::create(rewriter, loc, accTy, highHalf).getResult();
+
+  rewriteUses(opA->getResult(0), accA, contractA, rewriter);
+  rewriteUses(opB->getResult(0), accB, contractB, rewriter);
+}
+
+/// Interleaves the vectors stored by the write-like ops `opA` and `opB`
+/// (vector.transfer_write / vector.store) and makes each op store its
+/// shuffled half instead (low half to `opA`, high half to `opB`).
+///
+/// The new ops are inserted before whichever write comes first, so both ops
+/// must be in the same block and both stored vectors must already be defined
+/// at that point. Operand updates go through the rewriter so the pattern
+/// driver is notified.
+void shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
+                              mlir::Operation *opA, mlir::Operation *opB,
+                              int64_t nonUnitDimAcc, mlir::VectorType accTy) {
+  // Returns the vector stored by a write-like op, or a null Value.
+  auto getWrittenVector = [](mlir::Operation *op) -> mlir::Value {
+    if (auto write = mlir::dyn_cast<mlir::vector::TransferWriteOp>(op))
+      return write.getVector();
+    if (auto store = mlir::dyn_cast<mlir::vector::StoreOp>(op))
+      return store.getValueToStore();
+    return {};
+  };
+
+  mlir::Value vecA = getWrittenVector(opA);
+  mlir::Value vecB = getWrittenVector(opB);
+
+  assert(vecA && vecB && "expected vector write-like ops");
+
+  // Insert before the earlier write so both writes see the shuffled values.
+  mlir::Operation *insertBefore = opA->isBeforeInBlock(opB) ? opA : opB;
+
+  rewriter.setInsertionPoint(insertBefore);
+  mlir::Location loc = insertBefore->getLoc();
+
+  auto flatTy = mlir::VectorType::get(nonUnitDimAcc, accTy.getElementType());
+
+  // Flatten to 1-D so the shuffle masks can address every element.
+  auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy, vecA);
+  auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB);
+
+  // TODO: derive shuffle masks instead of hard-coding
+  auto masks = getShuffleMasks(nonUnitDimAcc);
+
+  auto shuffledLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy,
+                                                    castA, castB, masks.maskLo);
+  auto shuffledHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy,
+                                                    castA, castB, masks.maskHi);
+
+  // Cast back to accumulator type
+  auto newVecA =
+      mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledLo);
+  auto newVecB =
+      mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledHi);
+
+  // Operand 0 is the stored vector of both vector.transfer_write and
+  // vector.store. In-place updates must be reported to the rewriter.
+  rewriter.modifyOpInPlace(
+      opA, [&] { opA->setOperand(0, newVecA.getResult()); });
+  rewriter.modifyOpInPlace(
+      opB, [&] { opB->setOperand(0, newVecB.getResult()); });
+}
+
+/// Interleaves the results of the read-like ops `opA` and `opB` using
+/// hard-coded 32-entry masks.
+///
+/// After flattening, each 8-element group of the low (high) result
+/// alternates A/B element-wise over the first (last) 4 elements of the
+/// corresponding 8-element groups of A and B. vector.contract and scf.for
+/// users of opA's / opB's result are redirected to the low / high result.
+/// Because the masks have 32 entries, `nonUnitDimAcc` must be 32.
+void shuffleNonUnitDimOperand(mlir::PatternRewriter &rewriter,
+                              mlir::Operation *opA, mlir::Operation *opB,
+                              mlir::vector::ContractionOp contractA,
+                              mlir::vector::ContractionOp contractB,
+                              int64_t nonUnitDimAcc, mlir::VectorType Ty) {
+  static constexpr int64_t maskLo[] = {
+      0,  32, 1,  33, 2,  34, 3,  35, 8,  40, 9,  41, 10, 42, 11, 43,
+      16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
+  static constexpr int64_t maskHi[] = {
+      4,  36, 5,  37, 6,  38, 7,  39, 12, 44, 13, 45, 14, 46, 15, 47,
+      20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
+
+  // vector.shuffle requires the result length to equal the mask length, and
+  // the mask indices assume two 32-element operands.
+  assert(nonUnitDimAcc == 32 &&
+         "shuffle masks are hard-coded for 32-element vectors");
+
+  // Both read results must be available at the insertion point.
+  mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
+
+  rewriter.setInsertionPointAfter(insertAfter);
+  mlir::Location loc = insertAfter->getLoc();
+
+  auto elemTy = Ty.getElementType();
+  auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
+
+  auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
+                                                 opA->getResult(0));
+  auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
+                                                 opB->getResult(0));
+
+  auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+                                                   castB, maskLo);
+  auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+                                                   castB, maskHi);
+
+  auto newAccA =
+      mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
+  auto newAccB =
+      mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
+
+  rewriteUses(opA->getResult(0), newAccA.getResult(), contractA, rewriter);
+  rewriteUses(opB->getResult(0), newAccB.getResult(), contractB, rewriter);
+}
+
+/// Returns true if `pairContOp` forms a shuffle pair with `contractOp`.
+///
+/// The pair must share the LHS (when `rhsHasMultipleNonUnitDims`) or the RHS
+/// (otherwise). The remaining operands must be produced by
+/// vector.transfer_read / vector.load ops on the same buffer. Each index
+/// pair must be either identical or constant, with the pair's index equal to
+/// or exactly `nonUnitDimValue` larger than the original. At least one index
+/// must be offset, so the two ops read different data.
+bool validatePairVectorContract(vector::ContractionOp contractOp,
+                                vector::ContractionOp pairContOp,
+                                bool rhsHasMultipleNonUnitDims,
+                                int64_t nonUnitDimValue) {
+  // The operand without the non-unit dims must be shared by both ops.
+  if (rhsHasMultipleNonUnitDims ? contractOp.getLhs() != pairContOp.getLhs()
+                                : contractOp.getRhs() != pairContOp.getRhs())
+    return false;
+
+  Value operand =
+      rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
+  Value pairOperand =
+      rhsHasMultipleNonUnitDims ? pairContOp.getRhs() : pairContOp.getLhs();
+
+  // Extracts the buffer and indices of the vector.transfer_read /
+  // vector.load defining `v`. Returns false for any other producer,
+  // including block arguments, which have no defining op.
+  auto getReadSource = [](Value v, Value &base,
+                          SmallVector<OpFoldResult> &indices) -> bool {
+    Operation *defOp = v.getDefiningOp();
+    if (!defOp)
+      return false;
+    llvm::TypeSwitch<mlir::Operation *>(defOp)
+        .Case<vector::TransferReadOp, vector::LoadOp>([&](auto readOp) {
+          base = readOp.getOperand(0);
+          indices = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
+                                              readOp.getIndices().end());
+        });
+    return static_cast<bool>(base);
+  };
+
+  Value base, pairBase;
+  SmallVector<OpFoldResult> indices, pairIndices;
+  if (!getReadSource(operand, base, indices) ||
+      !getReadSource(pairOperand, pairBase, pairIndices))
+    return false;
+
+  if (base != pairBase || indices.size() != pairIndices.size())
+    return false;
+
+  bool hasOffset = false;
+  for (auto [idx, pairIdx] : llvm::zip_equal(indices, pairIndices)) {
+    // The same SSA value / attribute is trivially the same index.
+    if (idx == pairIdx)
+      continue;
+
+    // Distinct dynamic indices cannot be proven to match or to be offset.
+    std::optional<int64_t> cst = getConstantIntValue(idx);
+    std::optional<int64_t> pairCst = getConstantIntValue(pairIdx);
+    if (!cst || !pairCst)
+      return false;
+
+    if (*cst == *pairCst)
+      continue;
+
+    if (*pairCst - *cst != nonUnitDimValue)
+      return false;
+    hasOffset = true;
+  }
+
+  return hasOffset;
+}
+
 } // namespace x86vector
 } // namespace mlir
diff --git a/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir b/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
new file mode 100644
index 0000000000000..7b1a6eaa1258a
--- /dev/null
+++ b/mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+
+// The two contractions below share the LHS load %1 and read their RHS
+// operands from %arg1 8 elements apart. Their 1x8 f32 accumulators are
+// loaded from and stored to %arg2 8 elements apart, so the pattern treats
+// them as a pair.
+!vecA = vector<1x1x1xbf16>
+!vecB = vector<1x1x8xbf16>
+!vecC = vector<1x8xf32>
+!memrefA = memref<1x4x1xbf16>
+!memrefB = memref<1x1x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0,  d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0,  d1, d2, d3) -> (d1, d2)>
+func.func @shuffle_VC_output_flat_layout(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %1 = vector.load %arg0[%c0, %c0, %c0] :
+        !memrefA, !vecA
+  %2 = vector.load %arg1[%c0, %c0, %c0] :
+        !memrefB, !vecB
+  %3 = vector.load %arg1[%c0, %c0, %c8] :
+        !memrefB, !vecB
+  %4 = vector.load %arg2[%c0, %c0] :
+        !memrefC, !vecC
+  %5 = vector.load %arg2[%c0, %c8] :
+        !memrefC, !vecC
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %2, %4
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %3, %5
+    : !vecA, !vecB into !vecC
+
+  vector.store %6, %arg2[%c0, %c0] : !memrefC, !vecC
+
+  vector.store %7, %arg2[%c0, %c8] : !memrefC, !vecC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @shuffle_VC_output_flat_layout
+// Expected: two shuffles interleave the accumulators before the
+// contractions, and two more interleave the results before they are stored.
+// CHECK: vector.shuffle
+// CHECK-NEXT: vector.shuffle
+// CHECK: vector.contract
+// CHECK: vector.shuffle
+// CHECK-NEXT: vector.shuffle
+
+// Transform script: apply the BF16 contract-result shuffle patterns to every
+// func.func in the payload.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.shuffle_bf16_vector_contract_result
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 7fcfa7772f5e15ebb12653a10ba0c9d061d4bd84 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 6 Jan 2026 06:37:29 -0800
Subject: [PATCH 2/2] fix header format issue

---
 .../X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp   | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
index 2081de9f2be91..faa1b0cd48131 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
@@ -1,5 +1,4 @@
-//===-
-//ShuffleBF16VectorContractResult.cpp-----------------------------------------===//
+//===- ShuffleBF16VectorContractResult.cpp --------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.



More information about the Mlir-commits mailing list