[Mlir-commits] [mlir] [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (PR #174590)
Arun Thangamani
llvmlistbot at llvm.org
Mon Feb 16 02:56:36 PST 2026
================
@@ -297,6 +407,80 @@ struct VectorContractBF16ToFMA
VectorType dstType =
VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
+ if (!isVnni) {
+
+ // Validate and shuffle the accumulator
+ Operation *accReadOp0 =
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
+ Operation *accReadOp1 =
+ traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+
+ // Iterate down to find the users of contract operations until it is store
+ // or transfer_write.
+ Operation *resultWriteOp0 =
+ traceToVectorWriteLikeUserOperation(contractOp.getResult());
+ Operation *resultWriteOp1 =
+ traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
+
+ // Shuffle the accumulators of the contract operations.
+ LogicalResult readShuffle =
+ shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
+ pairContractOp, nonUnitDim, accTy);
+
+ if (failed(readShuffle))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Accumulator read is not by transfer_read or load");
+
+ rewriter.setInsertionPoint(contractOp);
+
+ castAcc = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+ contractOp.getAcc());
+
+ auto loadBcstBF16ElementToF32 = x86vector::BcstToPackedF32Op::create(
+ rewriter, loc, dstType, unitDimSubview[0]);
+ auto loadEvenIdxElementF32 =
+ x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
+ nonUnitDimSubview[0]);
+ auto evenIdxFMA =
+ vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
+ loadEvenIdxElementF32, castAcc);
+ auto castEvenFma =
+ vector::ShapeCastOp::create(rewriter, loc, accTy, evenIdxFMA);
+ rewriter.replaceOp(contractOp, castEvenFma);
+
+ rewriter.setInsertionPoint(pairContractOp);
+ auto pairContOpLoc = pairContractOp.getLoc();
+ VectorType accTyPairCont =
+ dyn_cast<VectorType>(pairContractOp.getAccType());
+ auto castAccPairCont = vector::ShapeCastOp::create(
+ rewriter, pairContOpLoc,
+ VectorType::get(nonUnitDimAcc.front(),
+ accTyPairCont.getElementType()),
+ pairContractOp.getAcc());
+
+ auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
+ rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
+ auto oddIdxFMA = vector::FMAOp::create(
+ rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
+ loadOddIdxElementF32, castAccPairCont);
+ auto castOddFma = vector::ShapeCastOp::create(rewriter, pairContOpLoc,
+ accTyPairCont, oddIdxFMA);
+ rewriter.replaceOp(pairContractOp, castOddFma);
+
+ // Shuffle the output of contract operations before its use.
+ LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
----------------
arun-thmn wrote:
Yes, for now.
We only shuffle right before the `transfer_write` or `store`, but there are TODOs to cover other cases, such as element-wise ops in between.
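
A minimal sketch of why the flat layout needs a shuffle on both sides of the FMA pair, assuming a plain even/odd interleave. The real lowering may permute lanes differently. Everything here is hypothetical and not taken from the PR: `float` stands in for BF16, and the names `splitEvenOdd`, `mergeEvenOdd`, `fmaBroadcast`, and `flatUpdate` are made up for illustration. The even- and odd-indexed conversions each produce every other column, so the accumulator has to be regrouped the same way after it is read, and the two results have to be put back in contiguous order before they are written.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical sizes: one 16-column row handled as two 8-lane vectors.
constexpr std::size_t kLanes = 8;
using Vec = std::array<float, kLanes>;
using Row = std::array<float, 2 * kLanes>;

struct VecPair {
  Vec first;   // even columns
  Vec second;  // odd columns
};

// Shuffle after the read: regroup a contiguous row into even and odd columns.
VecPair splitEvenOdd(const Row &row) {
  VecPair out{};
  for (std::size_t i = 0; i < kLanes; ++i) {
    out.first[i] = row[2 * i];
    out.second[i] = row[2 * i + 1];
  }
  return out;
}

// Shuffle before the write: restore contiguous column order.
Row mergeEvenOdd(const Vec &even, const Vec &odd) {
  Row out{};
  for (std::size_t i = 0; i < kLanes; ++i) {
    out[2 * i] = even[i];
    out[2 * i + 1] = odd[i];
  }
  return out;
}

// Lane-wise acc + a * b, where the scalar `a` is broadcast to all lanes.
Vec fmaBroadcast(float a, const Vec &b, const Vec &acc) {
  Vec out{};
  for (std::size_t i = 0; i < kLanes; ++i)
    out[i] = acc[i] + a * b[i];
  return out;
}

// One rank-1 update of a flat row: split, two FMAs, merge.
Row flatUpdate(float a, const Row &b, const Row &acc) {
  VecPair bSplit = splitEvenOdd(b);      // stand-in for even/odd conversions
  VecPair accSplit = splitEvenOdd(acc);  // accumulator shuffle after the read
  Vec evenOut = fmaBroadcast(a, bSplit.first, accSplit.first);
  Vec oddOut = fmaBroadcast(a, bSplit.second, accSplit.second);
  return mergeEvenOdd(evenOut, oddOut);  // output shuffle before the write
}

// Straightforward column-by-column reference without any shuffling.
Row referenceUpdate(float a, const Row &b, const Row &acc) {
  Row out{};
  for (std::size_t j = 0; j < 2 * kLanes; ++j)
    out[j] = acc[j] + a * b[j];
  return out;
}

// Compares the shuffled path against the reference on small exact values.
bool matchesReference() {
  Row b{}, acc{};
  for (std::size_t j = 0; j < 2 * kLanes; ++j) {
    b[j] = static_cast<float>(j);
    acc[j] = 100.0f + static_cast<float>(j);
  }
  return flatUpdate(2.0f, b, acc) == referenceUpdate(2.0f, b, acc);
}

// Convenience check that aborts if the two paths disagree.
void selfCheck() { assert(matchesReference()); }
```

If either shuffle is skipped in this model, the even and odd results land in the wrong columns. An element-wise op sitting between the contraction and the write would presumably see the shuffled order as well, which seems to be the case the TODOs refer to.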
https://github.com/llvm/llvm-project/pull/174590