[Mlir-commits] [mlir] [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (PR #174590)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 12 07:59:04 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Arun Thangamani (arun-thmn)

<details>
<summary>Changes</summary>

This patch shuffles the output of a `bf16`-typed, non-VNNI-packed `vector.contract` operation (`flat` layout). The output of the contraction is shuffled to match the `flat` layout before it is stored to the `acc` matrix.

Following this transform schedule, the `vector.contract` will be lowered to one of the following operations:
  - x86vector::DotBF16Op, with the `B` matrix shuffled to compensate for the `flat` layout (supported as part of this PR), or
  - vector.fma with loads + broadcast using `bf16` packed operations (supported as part of this PR).

---

Patch is 108.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174590.diff


12 Files Affected:

- (modified) mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td (+11) 
- (modified) mlir/include/mlir/Dialect/X86Vector/Transforms.h (+4) 
- (modified) mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h (+35) 
- (modified) mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp (+5) 
- (modified) mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp (+192) 
- (modified) mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp (+142-31) 
- (modified) mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp (+228-98) 
- (modified) mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp (+295) 
- (modified) mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir (+495) 
- (modified) mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir (+392-20) 
- (added) mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir (+467) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 3c73eadf82167..00c611a9f3a7a 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -60,6 +60,17 @@ def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyShuffleBF16VectorContractResultPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.shuffle_bf16_vector_contract_result",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns that shuffle the results of flat-layout BF16
+    vector.contract operations.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 
 #endif // X86VECTOR_TRANSFORM_OPS
 
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index c25cdaf2d9428..e07fb4aedf539 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -100,6 +100,10 @@ void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
 // range by placing them at their earliest legal use site.
 void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
 
+// Collect patterns that shuffle results of flat-layout BF16 vector.contracts.
+void populateShuffleBF16VectorContractResultPatterns(
+    RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
index 2de9a3122cbd9..3b9a10f77d35f 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -9,6 +9,9 @@
 #ifndef MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
 #define MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
 
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
 #include <cstdint>
@@ -26,6 +29,38 @@ namespace x86vector {
 bool isInVnniLayout(Operation *op, llvm::ArrayRef<AffineMap> indexingMaps,
                     std::optional<unsigned> blockingFactor = std::nullopt);
 
+// Returns true if two contraction ops form a valid pair for VNNI packing.
+// It verifies that both contractions share the appropriate operand, read from
+// the same source buffer, and use constant indices that differ by 8 or 16.
+bool validatePairVectorContract(vector::ContractionOp contractOp,
+                                vector::ContractionOp pairContOp,
+                                bool rhsHasMultipleNonUnitDims,
+                                int64_t nonUnitDimValue);
+
+// Walks backward from a value to find its originating vector read-like op
+// (vector.transfer_read or vector.load), following scf.for iter-args but
+// stopping at layout-transforming ops; returns the read op or nullptr.
+Operation *traceToVectorReadLikeParentOperation(mlir::Value v);
+
+// Recursively traces a value to find a downstream vector write-like op
+// (vector.transfer_write or vector.store), crossing scf.for/yield but
+// stopping at layout-altering ops; returns the first match or nullptr.
+Operation *traceToVectorWriteLikeUserOperation(mlir::Value v);
+
+// Packs the accumulators of two flat BF16 vector.contract ops (read by `op`
+// and `op1`) into a VNNI-packed layout and replaces the original accumulators
+// of `contractOp` and `pairContractOp` to enable post-read packing.
+void shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *op,
+                            Operation *op1, vector::ContractionOp contractOp,
+                            vector::ContractionOp pairContractOp,
+                            int64_t nonUnitDimAcc, VectorType accTy);
+
+// Shuffles the results of two vector.contract ops into a flat layout before
+// they are written to memory by the write-like ops `op` and `op1`.
+void shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *op,
+                              Operation *op1, int64_t nonUnitDimAcc,
+                              VectorType accTy);
+
 } // namespace x86vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index e77d30c9c5ffb..e40ddd3a4b1c0 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -42,6 +42,11 @@ void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
   x86vector::populateSinkVectorProducerOpsPatterns(patterns);
 }
 
+void mlir::transform::ApplyShuffleBF16VectorContractResultPatternsOp::
+    populatePatterns(RewritePatternSet &rewritePatterns) {
+  x86vector::populateShuffleBF16VectorContractResultPatterns(rewritePatterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index bbd9be880eb0a..acbc7fcfb635e 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   VectorContractToPackedTypeDotProduct.cpp
   VectorContractBF16ToFMA.cpp
   SinkVectorProducerOps.cpp
+  ShuffleBF16VectorContractResult.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
new file mode 100644
index 0000000000000..24b8a7489dbfa
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
@@ -0,0 +1,192 @@
+//===- ShuffleBF16VectorContractResult.cpp --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+namespace {
+// Shuffle the output of BF16 type flat layout vector.contract operations.
+//
+// For example:
+// ```
+//   %1 = vector.load -> vector<1x1xbf16>
+//   %2 = vector.load from memref (%m1) -> vector<1x8xbf16>
+//   %3 = vector.load from memref (%m1) -> vector<1x8xbf16>
+//   %4 = vector.contract %1, %2, %arg0 ->  vector<1x8xf32>
+//   %5 = vector.contract %1, %3, %arg1 ->  vector<1x8xf32>
+//   vector.store %4, %m1
+//   vector.store %5, %m1
+// ```
+// to
+// ```
+//   %1 = vector.load -> vector<1x1xbf16>
+//   %2 = vector.load from memref (%m1) -> vector<1x8xbf16>
+//   %3 = vector.load from memref (%m1) -> vector<1x8xbf16>
+//   %4 = vector.shuffle %arg0, %arg1 [0, 8, 1, 9, 2, 10, 3, 11]
+//   %5 = vector.shuffle %arg0, %arg1 [4, 12, 5, 13, 6, 14, 7, 15]
+//   %6 = vector.contract %1, %2, %4 ->  vector<1x8xf32>
+//   %7 = vector.contract %1, %3, %5 ->  vector<1x8xf32>
+//   %8 = vector.shuffle %6, %7 [0, 8, 1, 9, 2, 10, 3, 11]
+//   %9 = vector.shuffle %6, %7 [4, 12, 5, 13, 6, 14, 7, 15]
+//   vector.store %8, %m1
+//   vector.store %9, %m1
+// ```
+//
+// Accumulators are interleaved after being read and results are interleaved
+// again before being written; the pair is found later in the same block.
+struct ShuffleBF16VectorContractResult
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind.");
+
+    // TODO: Move this validation to a common utility folder. Planned to
+    // do once (code refactoring), all architecture specific nanokernel
+    // passes are merged into the repo.
+    VectorType lhsTy = contractOp.getLhsType();
+    if (!lhsTy.getElementType().isBF16())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only BF16 lowering is supported.");
+
+    // This pattern only handles the flat (non-VNNI) layout.
+    if (isInVnniLayout(contractOp.getOperation(),
+                       contractOp.getIndexingMapsArray(),
+                       /*blockingFactor=*/2))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Input matrices in VNNI format.");
+
+    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accumulator type.");
+
+    if (!accTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only F32 accumulation supported for BF16 type.");
+
+    // The accumulator must have exactly one non-unit dim, of size 8 or 16.
+    SmallVector<int64_t> nonUnitDimAcc = getNonUnitDims(accTy.getShape());
+    if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "A or B should be a non-unit dim in acc.");
+
+    int64_t nonUnitDimValue = nonUnitDimAcc.front();
+    if (nonUnitDimValue != 8 && nonUnitDimValue != 16)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The accumulator dimension should be 8 or 16");
+
+    bool rhsHasMultipleNonUnitDims =
+        getNonUnitDims(contractOp.getRhsType().getShape()).size() >
+        getNonUnitDims(lhsTy.getShape()).size();
+
+    // Find the pair among the following ops in the same block. Besides the
+    // checks in validatePairVectorContract (shared unit-dim operand, same
+    // source memref, constant offsets differing by 8 or 16), the pair must
+    // use the same combining kind and accumulator type, since both
+    // accumulators are shuffled together using `accTy`.
+    vector::ContractionOp pairContractOp;
+    for (Operation *nextOp = contractOp->getNextNode(); nextOp;
+         nextOp = nextOp->getNextNode()) {
+      auto candidate = dyn_cast<vector::ContractionOp>(nextOp);
+      if (!candidate || candidate.getKind() != vector::CombiningKind::ADD ||
+          candidate.getAccType() != accTy)
+        continue;
+      if (validatePairVectorContract(contractOp, candidate,
+                                     rhsHasMultipleNonUnitDims,
+                                     nonUnitDimValue)) {
+        pairContractOp = candidate;
+        break;
+      }
+    }
+
+    if (!pairContractOp)
+      return rewriter.notifyMatchFailure(
+          contractOp, "Couldn't find pair contract operation for shuffling");
+
+    // Trace back to the load/transfer_read ops producing the accumulators.
+    Operation *accReadOp0 =
+        traceToVectorReadLikeParentOperation(contractOp.getAcc());
+    Operation *accReadOp1 =
+        traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+    if (!accReadOp0 || !accReadOp1)
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "Operands don't have load or transfer_read as their parent op");
+
+    // Trace down to the store/transfer_write ops consuming the results.
+    Operation *resultWriteOp0 =
+        traceToVectorWriteLikeUserOperation(contractOp.getResult());
+    Operation *resultWriteOp1 =
+        traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
+    if (!resultWriteOp0 || !resultWriteOp1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "The uses of contract operations are neither "
+                      "vector.store nor transfer_write");
+
+    // Both accumulators are shuffled before `contractOp`, so the pair's
+    // accumulator must already be read at that point.
+    if (contractOp->getBlock() == accReadOp1->getBlock() &&
+        contractOp->isBeforeInBlock(accReadOp1))
+      return rewriter.notifyMatchFailure(
+          contractOp, "The load/read operation of pair contract operation is "
+                      "after the contractOp");
+
+    // Both results are shuffled before the first write, so the write of
+    // `contractOp` must not precede `pairContractOp`.
+    if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
+        resultWriteOp0->isBeforeInBlock(pairContractOp))
+      return rewriter.notifyMatchFailure(
+          contractOp, "The store/write operation of contract operation is "
+                      "before the pair contract operation");
+
+    // Shuffle the accumulators of the contract operations.
+    shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
+                           pairContractOp, nonUnitDimValue, accTy);
+
+    // Shuffle the outputs of the contract operations before their use.
+    shuffleBeforeWriteLikeOp(rewriter, resultWriteOp0, resultWriteOp1,
+                             nonUnitDimValue, accTy);
+
+    return success();
+  }
+
+  // Returns the sizes of the non-unit dims of `shape`, in order.
+  static SmallVector<int64_t> getNonUnitDims(ArrayRef<int64_t> shape) {
+    SmallVector<int64_t> dims;
+    llvm::copy_if(shape, std::back_inserter(dims),
+                  [](int64_t dim) { return dim != 1; });
+    return dims;
+  }
+};
+
+} // namespace
+
+void x86vector::populateShuffleBF16VectorContractResultPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ShuffleBF16VectorContractResult>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index c60d9b91c18e5..eada03977595d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
@@ -29,7 +30,7 @@ using namespace mlir::x86vector;
 // Verifies that the LHS and RHS operands of a vector.contract are load or
 // vector.transfer_read operations on a memref source buffer, and checks
 // their bounds, dimensions, offsets, and strides.
-static bool validateVectorContractOperands(Value prodOp) {
+static bool validateVectorContractOperands(Value prodOp, bool isVnni) {
   Operation *defOp = prodOp.getDefiningOp();
   if (!defOp)
     return false;
@@ -62,11 +63,13 @@ static bool validateVectorContractOperands(Value prodOp) {
   // Return false if the two innermost strides of the memref are not contiguous.
   // The x86vector.avx.cvt.packed.even/odd.indexed_to_f32 operations require
   // an eight-element tuple of bf16 values to be contiguous.
-  if (!llvm::cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(2))
+  int dimsToCheck = isVnni ? 2 : 1;
+  if (!llvm::cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(
+          dimsToCheck))
     return false;
 
   // Return false if the vnni offset of load or transfer_read is not zero.
-  if (getConstantIntValue(indexVals.back()) != 0)
+  if (isVnni && getConstantIntValue(indexVals.back()) != 0)
     return false;
 
   return true;
@@ -96,7 +99,8 @@ static bool validateVectorContractOperands(Value prodOp) {
 // ```
 static SmallVector<memref::SubViewOp>
 getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
-                          ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim) {
+                          ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim,
+                          bool isVNNI) {
 
   Operation *defOp = prodOp.getDefiningOp();
 
@@ -122,11 +126,26 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
     }
   }
 
-  int vnniDimSize = isUnitDim ? 1 : 2;
+  auto one = rewriter.getIndexAttr(1);
+  llvm::SmallVector<memref::SubViewOp> subviews;
 
+  if (!isVNNI) {
+    SmallVector<OpFoldResult> strides(indexVals.size(), one);
+    SmallVector<OpFoldResult> sizes(indexVals.size(), one);
+    // Retrive twice the nonUnit dim BF16 element for both even and odd
+    // index elements.
+    if (!isUnitDim)
+      mnDimSize = 2 * mnDimSize;
+    sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
+    auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
+                                             sizes, strides);
+    subviews.push_back(subview);
+    return subviews;
+  }
+
+  int vnniDimSize = isUnitDim ? 1 : 2;
   auto nonVNNIDimSize = indexVals.size() - 1;
   // Create the size and stride offsets.
-  auto one = rewriter.getIndexAttr(1);
   SmallVector<OpFoldResult> strides(indexVals.size(), one);
   SmallVector<OpFoldResult> sizes(nonVNNIDimSize, one);
 
@@ -139,7 +158,6 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
   if (isUnitDim)
     indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
 
-  llvm::SmallVector<memref::SubViewOp> subviews;
   auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
                                            sizes, strides);
   subviews.push_back(subview);
@@ -168,7 +186,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
 // Implements outer product contraction as a sequence of BF16-packed
 // operation even/odd loads and FMA operations.
 //
-// For example:
+// For example (VNNI packed):
 // ```
 //   %1 = vector.load from memref (%m1) -> vector<1x1x2xbf16>
 //   %2 = vector.load from memref (%m2) -> vector<1x8x2xbf16>
@@ -183,6 +201,24 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
 //   %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
 //   return vector.fma %4, %5, %3
 // ```
+//
+// For example (Flat layout):
+// ```
+//   %1 = vector.load from memref (%m1) -> vector<1x1xbf16>
+//   %2 = vector.load from memref (%m2) -> vector<1x8xbf16>
+//   %3 = vector.contract %1, %2, %arg1
+//   %4 = vector.load from memref (%m2) -> vector<1x8xbf16>
+//   %5 = vector.contract %1, %4, %arg2
+//   scf.yield %3, %4
+// ```
+// to
+// ```
+//   %1 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
+//   %2 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
+//   %3 = vector.fma %1, %2, %arg1
+//   %4 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
+//   %5 = vector.fma %1, %4, %arg2
+//   scf.yield %3, %5
 struct VectorContractBF16ToFMA
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -202,11 +238,9 @@ struct VectorContractBF16ToFMA
       return rewriter.notifyMatchFailure(contractOp,
                                          "Only BF16 lowering is supported.");
 
-    if (!isInVnniLayout(contractOp.getOperation(),
-                        contractOp.getIndexingMapsArray(),
-                        /*blockingFactor=*/2))
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Input matrices not in VNNI format.");
+    bool isVnni = isInVnniLayout(contractOp.getOperation(),
+                                 contractOp.getIndexingMapsArray(),
+                                 /*blockingFactor=*/2);
 
     VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
     if (!accTy)
@@ -216,6 +250,14 @@ struct VectorContractBF16ToFMA
       return rewriter.notifyMatchFailure(
           contractOp, "Only F32 acumulation supported for BF16 type.");
 
+    ArrayRef<int64_t> accShape = accTy.getShape();
+    llvm::SmallVector<int64_t> nonUnitDi...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/174590


More information about the Mlir-commits mailing list