[Mlir-commits] [mlir] [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (PR #174590)
Arun Thangamani
llvmlistbot at llvm.org
Thu Feb 12 02:14:03 PST 2026
================
@@ -297,6 +372,104 @@ struct VectorContractBF16ToFMA
VectorType dstType =
VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
+ if (!isVnni) {
+
+ // Validate and shuffle the accumulator
+ Operation *accReadOp0 =
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
+ Operation *accReadOp1 =
+ traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+
+ // Iterate down through the users of the contract operations until a
+ // vector.store or transfer_write is reached.
+ Operation *resultWriteOp0 =
+ traceToVectorWriteLikeUserOperation(contractOp.getResult());
+ Operation *resultWriteOp1 =
+ traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
+
+ if (!accReadOp0 || !accReadOp1)
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "Operands doesn't have load or transfer_read as it's parent op");
+
+ if (!resultWriteOp0 || !resultWriteOp1)
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "The use of contract operations are neither vector.store "
+ "or transfer_write");
+
+ if (contractOp->getBlock() == accReadOp1->getBlock() &&
+ contractOp->isBeforeInBlock(accReadOp1))
+ return rewriter.notifyMatchFailure(
+ contractOp, "The load/read operation of pair contract operation is "
+ "after the contractOp");
+
+ if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
+ resultWriteOp0->isBeforeInBlock(pairContractOp)) {
+ return rewriter.notifyMatchFailure(
+ contractOp, "The store/write operation of contract operation is "
+ "before the pair contract operation");
+ }
+
+ // Shuffle the accumulators of the contract operations.
+ LogicalResult readShuffle =
+ shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
+ pairContractOp, nonUnitDim, accTy);
+
+ if (failed(readShuffle))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Accumulator read is not by transfer_read or load");
+
+ rewriter.setInsertionPoint(contractOp);
+
+ castAcc = vector::ShapeCastOp::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+ contractOp.getAcc());
+
+ auto loadBcstBF16ElementToF32 = x86vector::BcstToPackedF32Op::create(
+ rewriter, loc, dstType, unitDimSubview[0]);
+ auto loadEvenIdxElementF32 =
+ x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
+ nonUnitDimSubview[0]);
+ auto evenIdxFMA =
+ vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
+ loadEvenIdxElementF32, castAcc);
+ auto castEvenFma =
+ vector::ShapeCastOp::create(rewriter, loc, accTy, evenIdxFMA);
+ rewriter.replaceOp(contractOp, castEvenFma);
+
+ rewriter.setInsertionPoint(pairContractOp);
+ auto pairContOpLoc = pairContractOp.getLoc();
+ VectorType accTyPairCont =
+ dyn_cast<VectorType>(pairContractOp.getAccType());
+ auto castAccPairCont = vector::ShapeCastOp::create(
+ rewriter, pairContOpLoc,
+ VectorType::get(nonUnitDimAcc.front(),
+ accTyPairCont.getElementType()),
+ pairContractOp.getAcc());
+
+ auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
+ rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
+ auto oddIdxFMA = vector::FMAOp::create(
+ rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
+ loadOddIdxElementF32, castAccPairCont);
+ auto castOddFma = vector::ShapeCastOp::create(rewriter, pairContOpLoc,
+ accTyPairCont, oddIdxFMA);
+ rewriter.replaceOp(pairContractOp, castOddFma);
+
+ // Shuffle the output of the contract operations before its use.
----------------
arun-thmn wrote:
Fixed.
https://github.com/llvm/llvm-project/pull/174590
More information about the Mlir-commits
mailing list