[Mlir-commits] [mlir] [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (PR #174590)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 12 07:59:04 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector
Author: Arun Thangamani (arun-thmn)
<details>
<summary>Changes</summary>
This patch shuffles the output of a `bf16`, non-VNNI-packed (`flat` layout) `vector.contract` operation. The output of the contraction is shuffled to match the `flat` layout before it is stored in the `acc` matrix.
Following this transform schedule, the `vector.contract` will be lowered to one of the following operations:
- `x86vector::DotBF16Op`, with the `B` matrix shuffled to compensate for the `flat` layout (supported as part of this PR), or
- vector.fma with loads + broadcast using `bf16` packed operations (supported as part of this PR).
---
Patch is 108.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174590.diff
12 Files Affected:
- (modified) mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td (+11)
- (modified) mlir/include/mlir/Dialect/X86Vector/Transforms.h (+4)
- (modified) mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h (+35)
- (modified) mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp (+5)
- (modified) mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp (+192)
- (modified) mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp (+142-31)
- (modified) mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp (+228-98)
- (modified) mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp (+295)
- (modified) mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir (+495)
- (modified) mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir (+392-20)
- (added) mlir/test/Dialect/shuffle-bf16-vector-contract-result.mlir (+467)
``````````diff
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 3c73eadf82167..00c611a9f3a7a 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -60,6 +60,17 @@ def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+// Pattern-descriptor transform op; its populatePatterns() forwards to
+// x86vector::populateShuffleBF16VectorContractResultPatterns.
+def ApplyShuffleBF16VectorContractResultPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.x86vector.shuffle_bf16_vector_contract_result",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collect patterns to shuffle results of flat layout BF16 type
+ vector.contract operations.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
#endif // X86VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index c25cdaf2d9428..e07fb4aedf539 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -100,6 +100,10 @@ void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
// range by placing them at their earliest legal use site.
void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
+// Shuffle the output of BF16 type flat layout vector.contract operations.
+void populateShuffleBF16VectorContractResultPatterns(
+ RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
index 2de9a3122cbd9..3b9a10f77d35f 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -9,6 +9,9 @@
#ifndef MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
#define MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
@@ -26,6 +29,38 @@ namespace x86vector {
bool isInVnniLayout(Operation *op, llvm::ArrayRef<AffineMap> indexingMaps,
std::optional<unsigned> blockingFactor = std::nullopt);
+// Returns true if two contraction ops form a valid pair for VNNI packing.
+// It verifies that both contractions share the appropriate operand, read from
+// the same source buffer, and use constant indices that differ by 8 or 16.
+bool validatePairVectorContract(vector::ContractionOp contractOp,
+ vector::ContractionOp pairContOp,
+ bool rhsHasMultipleNonUnitDims,
+ int64_t nonUnitDimValue);
+
+// Walks backward from a value to find its originating vector read-like op
+// (vector.transfer_read or vector.load), following scf.for iter-args but
+// stopping at layout-transforming ops; returns the read op or nullptr.
+Operation *traceToVectorReadLikeParentOperation(mlir::Value v);
+
+// Recursively traces a value to find a downstream vector write-like op
+// (vector.transfer_write or vector.store), crossing scf.for/yield but
+// stopping at layout-altering ops; returns the first match or nullptr.
+Operation *traceToVectorWriteLikeUserOperation(mlir::Value v);
+
+// Packs the accumulators of two flat BF16 vector.contraction ops into a
+// VNNI-packed layout and replaces the original accumulators to enable post-read
+// packing transformations.
+void shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *op,
+ Operation *op1, vector::ContractionOp contractOp,
+ vector::ContractionOp pairContractOp,
+ int64_t nonUnitDimAcc, VectorType accTy);
+
+// Shuffles vectors produced by vector.contraction ops into a flat layout
+// before they are written to memory.
+void shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *op,
+ Operation *op1, int64_t nonUnitDimAcc,
+ VectorType accTy);
+
} // namespace x86vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index e77d30c9c5ffb..e40ddd3a4b1c0 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -42,6 +42,11 @@ void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
x86vector::populateSinkVectorProducerOpsPatterns(patterns);
}
+// Adds the flat-layout BF16 contraction-result shuffle patterns to `patterns`.
+void mlir::transform::ApplyShuffleBF16VectorContractResultPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ x86vector::populateShuffleBF16VectorContractResultPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index bbd9be880eb0a..acbc7fcfb635e 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
VectorContractToPackedTypeDotProduct.cpp
VectorContractBF16ToFMA.cpp
SinkVectorProducerOps.cpp
+ ShuffleBF16VectorContractResult.cpp
LINK_LIBS PUBLIC
MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
new file mode 100644
index 0000000000000..24b8a7489dbfa
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleBF16VectorContractResult.cpp
@@ -0,0 +1,192 @@
+//===- ShuffleBF16VectorContractResult.cpp --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Shuffles the accumulators and results of a pair of flat-layout (non-VNNI)
+// BF16 vector.contract operations. Both accumulators are shuffled right after
+// they are read, and both results are shuffled right before they are written.
+//
+// For example:
+// ```
+// %1 = vector.load -> vector<1x1xbf16>
+// %2 = vector.load from memref (%m1) -> vector<1x8xbf16>
+// %3 = vector.load from memref (%m1) -> vector<1x8xbf16>
+// %4 = vector.contract %1, %2, %arg0 -> vector<1x8xf32>
+// %5 = vector.contract %1, %3, %arg1 -> vector<1x8xf32>
+// vector.store %4, %m1
+// vector.store %5, %m1
+// ```
+// to
+// ```
+// %1 = vector.load -> vector<1x1xbf16>
+// %2 = vector.load from memref (%m1) -> vector<1x8xbf16>
+// %3 = vector.load from memref (%m1) -> vector<1x8xbf16>
+// %4 = vector.shuffle %arg0, %arg1 [0, 8, 1, 9, 2, 10, 3, 11]
+// %5 = vector.shuffle %arg0, %arg1 [4, 12, 5, 13, 6, 14, 7, 15]
+// %6 = vector.contract %1, %2, %4 -> vector<1x8xf32>
+// %7 = vector.contract %1, %3, %5 -> vector<1x8xf32>
+// %8 = vector.shuffle %6, %7 [0, 8, 1, 9, 2, 10, 3, 11]
+// %9 = vector.shuffle %6, %7 [4, 12, 5, 13, 6, 14, 7, 15]
+// vector.store %8, %m1
+// vector.store %9, %m1
+//```
+// NOTE(review): the accumulator is combined lane-wise (by vector.contract and
+// by the vector.fma / DotBF16 forms it is lowered to). The accumulator shuffle
+// therefore has to be the inverse permutation of the result shuffle. The
+// example uses the same interleave masks on both sides, and that interleave is
+// not its own inverse. Confirm the masks against shuffleAfterReadLikeOp.
+struct ShuffleBF16VectorContractResult
+ : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ // Matches `contractOp` as the first op of a pair and rewrites both ops of
+ // the pair. The second op is searched for only after `contractOp` in the
+ // same block.
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ // Only contractions that combine with addition are handled.
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind.");
+
+ // TODO: Move this validation to a common utility folder. Planned to
+ // do once (code refactoring), all architecture specific nanokernel
+ // passes are merged into the repo.
+ VectorType lhsTy = contractOp.getLhsType();
+ if (!lhsTy.getElementType().isBF16())
+ return rewriter.notifyMatchFailure(contractOp,
+ "Only BF16 lowering is supported.");
+
+ // VNNI-packed inputs (blocking factor 2) are rejected; only the flat
+ // layout is handled by this pattern.
+ if (isInVnniLayout(contractOp.getOperation(),
+ contractOp.getIndexingMapsArray(),
+ /*blockingFactor=*/2))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Input matrices in VNNI format.");
+
+ // The accumulator must be a vector of F32.
+ VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
+ if (!accTy)
+ return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
+
+ if (!accTy.getElementType().isF32())
+ return rewriter.notifyMatchFailure(
+ contractOp, "Only F32 acumulation supported for BF16 type.");
+
+ // Exactly one accumulator dimension may be non-unit. Its size (8 or 16) is
+ // used to validate the pair and is passed to both shuffle helpers below.
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimAcc;
+ llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
+ [](int64_t dim) { return dim != 1; });
+
+ if (nonUnitDimAcc.size() != 1)
+ return rewriter.notifyMatchFailure(
+ contractOp, "A or B should be a non-unit dim in acc.");
+
+ int64_t nonUnitDimValue = nonUnitDimAcc.front();
+
+ if (nonUnitDimValue != 8 && nonUnitDimValue != 16)
+ return rewriter.notifyMatchFailure(
+ contractOp, "The accumulator dimension should be 8 or 16");
+
+ // Collect the non-unit dims of LHS and RHS to see which operand carries
+ // more of them.
+ ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimLhs;
+ llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
+ [](int64_t dim) { return dim != 1; });
+
+ VectorType rhsTy = contractOp.getRhsType();
+ ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDimRhs;
+ llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
+ [](int64_t dim) { return dim != 1; });
+
+ // rhsHasMultipleNonUnitDims is passed to validatePairVectorContract. Per
+ // its declaration, that helper checks that both contractions share the
+ // appropriate operand.
+ vector::ContractionOp pairContractOp;
+ bool rhsHasMultipleNonUnitDims =
+ nonUnitDimRhs.size() > nonUnitDimLhs.size();
+
+ // Get the pair vector.contract operation. A later op in the same block
+ // qualifies when:
+ // (1) its unit-dim operand (LHS or RHS) is the same as contractOp's,
+ // (2) its non-unit-dim operand is read from the same source memref, and
+ // (3) the non-unit-dim offsets of the two ops differ by 8 or 16.
+ // See validatePairVectorContract.
+ Operation *nextOp = contractOp;
+ while ((nextOp = nextOp->getNextNode())) {
+ auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
+
+ if (!contOp)
+ continue;
+
+ if (validatePairVectorContract(
+ contractOp, contOp, rhsHasMultipleNonUnitDims, nonUnitDimValue)) {
+ pairContractOp = contOp;
+ break;
+ }
+ }
+
+ if (!pairContractOp)
+ return rewriter.notifyMatchFailure(
+ contractOp, "Coudn't find pair contract operation for shuffling");
+
+ // Trace back to the load or transfer_read operations of the contract
+ // accumulators.
+ Operation *accReadOp0 =
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
+ Operation *accReadOp1 =
+ traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+
+ // Trace forward through the users of the contract results to the
+ // vector.store or vector.transfer_write that consumes each of them.
+ Operation *resultWriteOp0 =
+ traceToVectorWriteLikeUserOperation(contractOp.getResult());
+ Operation *resultWriteOp1 =
+ traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
+
+ if (!accReadOp0 || !accReadOp1)
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "Operands doesn't have load or transfer_read as it's parent op");
+
+ if (!resultWriteOp0 || !resultWriteOp1)
+ return rewriter.notifyMatchFailure(
+ contractOp, "The use of contract operations are neither vector.store "
+ "or transfer_write");
+
+ // Bail out if, within one block, the pair's accumulator read comes after
+ // contractOp. Presumably the shuffled accumulator feeding contractOp needs
+ // both reads to already be available.
+ if (contractOp->getBlock() == accReadOp1->getBlock() &&
+ contractOp->isBeforeInBlock(accReadOp1))
+ return rewriter.notifyMatchFailure(
+ contractOp, "The load/read operation of pair contract operation is "
+ "after the contractOp");
+
+ // Likewise, bail out if contractOp's result is written before the pair op
+ // runs. Presumably the result shuffle needs both results.
+ if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
+ resultWriteOp0->isBeforeInBlock(pairContractOp))
+ return rewriter.notifyMatchFailure(
+ contractOp, "The store/write operation of contract operation is "
+ "before the pair contract operation");
+
+ // Both helpers below are declared void, so from here on the IR is modified
+ // and the pattern always reports success.
+ // Shuffle the accumulators of both contract operations after their reads.
+ shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
+ pairContractOp, nonUnitDimValue, accTy);
+
+ // Shuffle the results of both contract operations before they are written.
+ shuffleBeforeWriteLikeOp(rewriter, resultWriteOp0, resultWriteOp1,
+ nonUnitDimValue, accTy);
+
+ return success();
+ }
+};
+
+// Registers the ShuffleBF16VectorContractResult rewrite pattern in `patterns`.
+void x86vector::populateShuffleBF16VectorContractResultPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ShuffleBF16VectorContractResult>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
index c60d9b91c18e5..eada03977595d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
@@ -29,7 +30,7 @@ using namespace mlir::x86vector;
// Verifies that the LHS and RHS operands of a vector.contract are load or
// vector.transfer_read operations on a memref source buffer, and checks
// their bounds, dimensions, offsets, and strides.
-static bool validateVectorContractOperands(Value prodOp) {
+static bool validateVectorContractOperands(Value prodOp, bool isVnni) {
Operation *defOp = prodOp.getDefiningOp();
if (!defOp)
return false;
@@ -62,11 +63,13 @@ static bool validateVectorContractOperands(Value prodOp) {
// Return false if the two innermost strides of the memref are not contiguous.
// The x86vector.avx.cvt.packed.even/odd.indexed_to_f32 operations require
// an eight-element tuple of bf16 values to be contiguous.
- if (!llvm::cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(2))
+ int dimsToCheck = isVnni ? 2 : 1;
+ if (!llvm::cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(
+ dimsToCheck))
return false;
// Return false if the vnni offset of load or transfer_read is not zero.
- if (getConstantIntValue(indexVals.back()) != 0)
+ if (isVnni && getConstantIntValue(indexVals.back()) != 0)
return false;
return true;
@@ -96,7 +99,8 @@ static bool validateVectorContractOperands(Value prodOp) {
// ```
static SmallVector<memref::SubViewOp>
getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
- ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim) {
+ ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim,
+ bool isVNNI) {
Operation *defOp = prodOp.getDefiningOp();
@@ -122,11 +126,26 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
}
}
- int vnniDimSize = isUnitDim ? 1 : 2;
+ auto one = rewriter.getIndexAttr(1);
+ llvm::SmallVector<memref::SubViewOp> subviews;
+ if (!isVNNI) {
+ SmallVector<OpFoldResult> strides(indexVals.size(), one);
+ SmallVector<OpFoldResult> sizes(indexVals.size(), one);
+ // Retrive twice the nonUnit dim BF16 element for both even and odd
+ // index elements.
+ if (!isUnitDim)
+ mnDimSize = 2 * mnDimSize;
+ sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
+ auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
+ sizes, strides);
+ subviews.push_back(subview);
+ return subviews;
+ }
+
+ int vnniDimSize = isUnitDim ? 1 : 2;
auto nonVNNIDimSize = indexVals.size() - 1;
// Create the size and stride offsets.
- auto one = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> strides(indexVals.size(), one);
SmallVector<OpFoldResult> sizes(nonVNNIDimSize, one);
@@ -139,7 +158,6 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
if (isUnitDim)
indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
- llvm::SmallVector<memref::SubViewOp> subviews;
auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
sizes, strides);
subviews.push_back(subview);
@@ -168,7 +186,7 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
// Implements outer product contraction as a sequence of BF16-packed
// operation even/odd loads and FMA operations.
//
-// For example:
+// For example (VNNI packed):
// ```
// %1 = vector.load from memref (%m1) -> vector<1x1x2xbf16>
// %2 = vector.load from memref (%m2) -> vector<1x8x2xbf16>
@@ -183,6 +201,24 @@ getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
// return vector.fma %4, %5, %3
// ```
+//
+// For example (Flat layout):
+// ```
+// %1 = vector.load from memref (%m1) -> vector<1x1xbf16>
+// %2 = vector.load from memref (%m2) -> vector<1x8xbf16>
+// %3 = vector.contract %1, %2, %arg1
+// %4 = vector.load from memref (%m2) -> vector<1x8xbf16>
+// %5 = vector.contract %1, %4, %arg2
+// scf.yield %3, %4
+// ```
+// to
+// ```
+// %1 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
+// %2 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
+// %3 = vector.fma %1, %2, %arg1
+// %4 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
+// %5 = vector.fma %1, %4, %arg2
+// scf.yield %3, %5
struct VectorContractBF16ToFMA
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -202,11 +238,9 @@ struct VectorContractBF16ToFMA
return rewriter.notifyMatchFailure(contractOp,
"Only BF16 lowering is supported.");
- if (!isInVnniLayout(contractOp.getOperation(),
- contractOp.getIndexingMapsArray(),
- /*blockingFactor=*/2))
- return rewriter.notifyMatchFailure(contractOp,
- "Input matrices not in VNNI format.");
+ bool isVnni = isInVnniLayout(contractOp.getOperation(),
+ contractOp.getIndexingMapsArray(),
+ /*blockingFactor=*/2);
VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
if (!accTy)
@@ -216,6 +250,14 @@ struct VectorContractBF16ToFMA
return rewriter.notifyMatchFailure(
contractOp, "Only F32 acumulation supported for BF16 type.");
+ ArrayRef<int64_t> accShape = accTy.getShape();
+ llvm::SmallVector<int64_t> nonUnitDi...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/174590
More information about the Mlir-commits
mailing list