[Mlir-commits] [mlir] [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (PR #174590)

Arun Thangamani llvmlistbot at llvm.org
Thu Feb 12 02:14:03 PST 2026


================
@@ -297,6 +372,104 @@ struct VectorContractBF16ToFMA
     VectorType dstType =
         VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
 
+    if (!isVnni) {
+
+      // Validate and shuffle the accumulator
+      Operation *accReadOp0 =
+          traceToVectorReadLikeParentOperation(contractOp.getAcc());
+      Operation *accReadOp1 =
+          traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+
+      // Iterate down through the users of the contract operations until a
+      // vector.store or transfer_write is found.
+      Operation *resultWriteOp0 =
+          traceToVectorWriteLikeUserOperation(contractOp.getResult());
+      Operation *resultWriteOp1 =
+          traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
+
+      if (!accReadOp0 || !accReadOp1)
+        return rewriter.notifyMatchFailure(
+            contractOp,
+            "Operands doesn't have load or transfer_read as it's parent op");
+
+      if (!resultWriteOp0 || !resultWriteOp1)
+        return rewriter.notifyMatchFailure(
+            contractOp,
+            "The use of contract operations are neither vector.store "
+            "or transfer_write");
+
+      if (contractOp->getBlock() == accReadOp1->getBlock() &&
+          contractOp->isBeforeInBlock(accReadOp1))
+        return rewriter.notifyMatchFailure(
+            contractOp, "The load/read operation of pair contract operation is "
+                        "after the contractOp");
+
+      if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
+          resultWriteOp0->isBeforeInBlock(pairContractOp)) {
+        return rewriter.notifyMatchFailure(
+            contractOp, "The store/write operation of contract operation is "
+                        "before the pair contract operation");
+      }
+
+      // Shuffle the accumulators of the contract operations.
+      LogicalResult readShuffle =
+          shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
+                                 pairContractOp, nonUnitDim, accTy);
+
+      if (failed(readShuffle))
+        return rewriter.notifyMatchFailure(
+            contractOp, "Accumulator read is not by transfer_read or load");
+
+      rewriter.setInsertionPoint(contractOp);
+
+      castAcc = vector::ShapeCastOp::create(
+          rewriter, loc,
+          VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
+          contractOp.getAcc());
+
+      auto loadBcstBF16ElementToF32 = x86vector::BcstToPackedF32Op::create(
+          rewriter, loc, dstType, unitDimSubview[0]);
+      auto loadEvenIdxElementF32 =
+          x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
+                                                         nonUnitDimSubview[0]);
+      auto evenIdxFMA =
+          vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
+                                loadEvenIdxElementF32, castAcc);
+      auto castEvenFma =
+          vector::ShapeCastOp::create(rewriter, loc, accTy, evenIdxFMA);
+      rewriter.replaceOp(contractOp, castEvenFma);
+
+      rewriter.setInsertionPoint(pairContractOp);
+      auto pairContOpLoc = pairContractOp.getLoc();
+      VectorType accTyPairCont =
+          dyn_cast<VectorType>(pairContractOp.getAccType());
+      auto castAccPairCont = vector::ShapeCastOp::create(
+          rewriter, pairContOpLoc,
+          VectorType::get(nonUnitDimAcc.front(),
+                          accTyPairCont.getElementType()),
+          pairContractOp.getAcc());
+
+      auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
+          rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
+      auto oddIdxFMA = vector::FMAOp::create(
+          rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
+          loadOddIdxElementF32, castAccPairCont);
+      auto castOddFma = vector::ShapeCastOp::create(rewriter, pairContOpLoc,
+                                                    accTyPairCont, oddIdxFMA);
+      rewriter.replaceOp(pairContractOp, castOddFma);
+
+      // Shuffle the output of contract operations before its use.
----------------
arun-thmn wrote:

Fixed.

https://github.com/llvm/llvm-project/pull/174590


More information about the Mlir-commits mailing list