[Mlir-commits] [mlir] [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (PR #174590)
Adam Siemieniuk
llvmlistbot at llvm.org
Wed Feb 11 00:59:27 PST 2026
================
@@ -104,5 +111,305 @@ bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
return true;
}
+// Pair of index masks consumed by two `vector.shuffle` ops that together
+// produce the "lo" and "hi" shuffled vectors from the same two inputs.
+// Both members are non-owning views (`llvm::ArrayRef`); the storage they
+// refer to must outlive every use of this struct.
+struct ShuffleMasks {
+ // Mask used to build the first ("lo") shuffled vector.
+ llvm::ArrayRef<int64_t> maskLo;
+ // Mask used to build the second ("hi") shuffled vector.
+ llvm::ArrayRef<int64_t> maskHi;
+};
+
+// Returns the lo/hi shuffle masks for an accumulator with `nonUnitDimAcc`
+// elements. Only 8 and 16 are supported, which is checked by an assert.
+// The returned views point at function-local static tables, so they
+// remain valid after the call returns.
+inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
+  // We only support these two layouts for now.
+  assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
+         "Unsupported nonUnitDimAcc value");
+
+  if (nonUnitDimAcc == 16) {
+    // Shuffle two <16xf32> vectors in groups of four elements,
+    // targeting AVX512.
+    static constexpr int64_t kLo16[] = {0, 1, 2, 3, 16, 17, 18, 19,
+                                        4, 5, 6, 7, 20, 21, 22, 23};
+    static constexpr int64_t kHi16[] = {8,  9,  10, 11, 24, 25, 26, 27,
+                                        12, 13, 14, 15, 28, 29, 30, 31};
+    return {kLo16, kHi16};
+  }
+
+  // Element-wise interleave of two <8xf32> vectors, targeting AVX2.
+  static constexpr int64_t kLo8[] = {0, 8, 1, 9, 2, 10, 3, 11};
+  static constexpr int64_t kHi8[] = {4, 12, 5, 13, 6, 14, 7, 15};
+  return {kLo8, kHi8};
+}
+
+// Walks backward from `v` to locate the vector read-like operation
+// (`vector.transfer_read` or `vector.load`) that originally produced it.
+//
+// The walk forwards through single-operand ops and from `scf.for`
+// iter-args to the corresponding init values. It stops and returns
+// `nullptr` when it reaches:
+//   - a layout-transforming op (`vector.shape_cast`, `vector.shuffle`),
+//   - an op with zero or more than one operand,
+//   - the induction variable of an `scf.for`,
+//   - a block argument that is not owned by an `scf.for` body.
+//
+// Returns the read-like defining operation, or `nullptr` if none is found.
+Operation *traceToVectorReadLikeParentOperation(Value v) {
+  while (true) {
+    // Case 1: Value defined by an operation.
+    if (Operation *defOp = v.getDefiningOp()) {
+      if (isa<vector::TransferReadOp, vector::LoadOp>(defOp))
+        return defOp;
+
+      // Layout-transforming ops terminate the search.
+      if (isa<vector::ShapeCastOp, vector::ShuffleOp>(defOp))
+        return nullptr;
+
+      // Unary op: keep walking through its only operand. Returning
+      // `defOp` here would hand the caller an op that is not read-like.
+      if (defOp->getNumOperands() == 1) {
+        v = defOp->getOperand(0);
+        continue;
+      }
+
+      return nullptr;
+    }
+
+    // Case 2: BlockArgument (scf.for iter_arg).
+    auto barg = dyn_cast<BlockArgument>(v);
+    if (!barg)
+      return nullptr;
+
+    // The parent op is null for a block that is not attached to an op.
+    Operation *parentOp = barg.getOwner()->getParentOp();
+    auto forOp = dyn_cast_or_null<scf::ForOp>(parentOp);
+    if (!forOp)
+      return nullptr;
+
+    // arg0 = induction variable (not an iter_arg).
+    unsigned argNum = barg.getArgNumber();
+    if (argNum == 0)
+      return nullptr;
+
+    // Continue from the init value tied to this iter_arg.
+    v = forOp.getInitArgs()[argNum - 1];
+  }
+}
+
+// Recursively follows the uses of `v` to find a downstream vector
+// write-like operation (`vector.transfer_write` or `vector.store`).
+//
+// Uses are visited in use-list order:
+//   - a write-like user is returned immediately,
+//   - a `vector.shape_cast` or `vector.shuffle` user aborts the whole
+//     search with `nullptr`,
+//   - an `scf.yield` user continues from the same-numbered result of the
+//     yield's parent op,
+//   - an `scf.for` user continues from the loop result tied to the init
+//     operand,
+//   - any other user continues from each of its results.
+//
+// Returns the first matching write-like user, or `nullptr` if none is found.
+Operation *traceToVectorWriteLikeUserOperation(Value v) {
+  for (OpOperand &use : v.getUses()) {
+    Operation *user = use.getOwner();
+
+    // --- TERMINAL OPS ---
+    if (isa<vector::TransferWriteOp, vector::StoreOp>(user))
+      return user;
+
+    if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
+      return nullptr;
+
+    // --- SCF YIELD ---
+    if (auto yield = dyn_cast<scf::YieldOp>(user)) {
+      Operation *parent = yield->getParentOp();
+      unsigned idx = use.getOperandNumber();
+      // Only follow the yielded value when the parent really has a
+      // result at that position; otherwise there is nothing to trace.
+      if (parent && idx < parent->getNumResults()) {
+        if (auto *res =
+                traceToVectorWriteLikeUserOperation(parent->getResult(idx)))
+          return res;
+      }
+      continue;
+    }
+
+    // --- SCF FOR ---
+    if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+      // The operands of scf.for start with lb/ub/step, so the raw operand
+      // number is not a result index. Ask the loop for the result tied to
+      // this init operand; the returned value is null when `use` is one
+      // of the loop-control operands.
+      if (OpResult tied = forOp.getTiedLoopResult(&use)) {
+        if (auto *res = traceToVectorWriteLikeUserOperation(tied))
+          return res;
+      }
+      continue;
+    }
+
+    // --- GENERIC CASE ---
+    for (Value res : user->getResults()) {
+      if (auto *found = traceToVectorWriteLikeUserOperation(res))
+        return found;
+    }
+  }
+
+  return nullptr;
+}
+
+// Redirects the uses of `oldVal` to `newVal`, but only for users that
+// are `vector.contract` or `scf.for` ops; every other use is left as is.
+// TODO: replace all uses with the packed value, not only those in
+// contraction and for ops.
+static void rewriteUses(mlir::Value oldVal, mlir::Value newVal) {
+  // Early-inc iteration: setting a use removes it from `oldVal`'s use
+  // list while that list is being walked.
+  for (mlir::OpOperand &operand :
+       llvm::make_early_inc_range(oldVal.getUses())) {
+    mlir::Operation *owner = operand.getOwner();
+    if (!mlir::isa<mlir::vector::ContractionOp, mlir::scf::ForOp>(owner))
+      continue;
+    operand.set(newVal);
+  }
+}
+
+// Inserts a shuffle of the values produced by `opA` and `opB` right
+// after whichever of the two comes later in the block.
+//
+// Result 0 of each op is flattened to a 1-D vector of `nonUnitDimAcc`
+// elements, the two flat vectors are shuffled with the lo/hi masks from
+// `getShuffleMasks`, and both shuffled vectors are cast back to `accTy`.
+// Uses of the original results in `vector.contract` and `scf.for` ops
+// are then redirected to the shuffled values (lo for `opA`, hi for
+// `opB`).
+//
+// `opA` and `opB` must be in the same block. `contractA` and `contractB`
+// are not referenced by this function body.
+void shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
+                            mlir::Operation *opA, mlir::Operation *opB,
+                            mlir::vector::ContractionOp contractA,
+                            mlir::vector::ContractionOp contractB,
+                            int64_t nonUnitDimAcc, mlir::VectorType accTy) {
+  // New ops go after the later of the two producers.
+  mlir::Operation *laterOp = opB;
+  if (!opA->isBeforeInBlock(opB))
+    laterOp = opA;
+  rewriter.setInsertionPointAfter(laterOp);
+  mlir::Location loc = laterOp->getLoc();
+
+  auto flatVecTy =
+      mlir::VectorType::get(nonUnitDimAcc, accTy.getElementType());
+
+  // Flatten both inputs.
+  mlir::Value srcA = opA->getResult(0);
+  mlir::Value srcB = opB->getResult(0);
+  auto flatA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatVecTy, srcA);
+  auto flatB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatVecTy, srcB);
+
+  // Shuffle the flat vectors into their lo and hi arrangements.
+  ShuffleMasks shuffleMasks = getShuffleMasks(nonUnitDimAcc);
+  auto loHalf = mlir::vector::ShuffleOp::create(rewriter, loc, flatVecTy, flatA,
+                                                flatB, shuffleMasks.maskLo);
+  auto hiHalf = mlir::vector::ShuffleOp::create(rewriter, loc, flatVecTy, flatA,
+                                                flatB, shuffleMasks.maskHi);
+
+  // Restore the accumulator shape.
+  auto packedA =
+      mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, loHalf);
+  auto packedB =
+      mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, hiHalf);
+
+  rewriteUses(srcA, packedA.getResult());
+  rewriteUses(srcB, packedB.getResult());
+}
+
+// Shuffles the vectors written by two write-like ops
+// (`vector.transfer_write` or `vector.store`) before they are stored.
+//
+// The vectors written by `opA` and `opB` are flattened to 1-D vectors of
+// `nonUnitDimAcc` elements, shuffled with the lo/hi masks from
+// `getShuffleMasks`, cast back to `accTy`, and installed as the new
+// written values (lo for `opA`, hi for `opB`). The new ops are inserted
+// before whichever write comes first in the block.
+//
+// `opA` and `opB` must be write-like ops in the same block.
+void shuffleBeforeWriteLikeOp(mlir::PatternRewriter &rewriter,
+                              mlir::Operation *opA, mlir::Operation *opB,
+                              int64_t nonUnitDimAcc, mlir::VectorType accTy) {
+  // Helper to extract the vector operand from write-like ops. Returns a
+  // null value for any other op.
+  auto getWrittenVector = [](mlir::Operation *op) -> mlir::Value {
+    if (auto write = mlir::dyn_cast<mlir::vector::TransferWriteOp>(op))
+      return write.getVector();
+    if (auto store = mlir::dyn_cast<mlir::vector::StoreOp>(op))
+      return store.getValueToStore();
+    return nullptr;
+  };
+
+  mlir::Value vecA = getWrittenVector(opA);
+  mlir::Value vecB = getWrittenVector(opB);
+  // A null value would otherwise be passed to ShapeCastOp::create below.
+  assert(vecA && vecB &&
+         "expected vector.transfer_write or vector.store operations");
+
+  // Decide insertion point and location: before the earlier write.
+  mlir::Operation *insertBefore = opA->isBeforeInBlock(opB) ? opA : opB;
+
+  rewriter.setInsertionPoint(insertBefore);
+  mlir::Location loc = insertBefore->getLoc();
+
+  auto elemTy = accTy.getElementType();
+  auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
+
+  // Flatten vectors.
+  auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy, vecA);
+  auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB);
+
+  // TODO: derive shuffle masks instead of hard-coding.
+  auto masks = getShuffleMasks(nonUnitDimAcc);
+
+  auto shuffledLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy,
+                                                    castA, castB, masks.maskLo);
+  auto shuffledHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy,
+                                                    castA, castB, masks.maskHi);
+
+  // Cast back to accumulator type.
+  auto newVecA =
+      mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledLo);
+  auto newVecB =
+      mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledHi);
+
+  // Update the write operands in place, going through the rewriter so the
+  // pattern driver is notified of the modification. As in the original
+  // code, operand 0 is taken to be the written vector for both op kinds.
+  rewriter.modifyOpInPlace(
+      opA, [&] { opA->setOperand(0, newVecA.getResult()); });
+  rewriter.modifyOpInPlace(
+      opB, [&] { opB->setOperand(0, newVecB.getResult()); });
+}
+
+// Return true if vector.contract operations matches on below conditions:
+// (1) - the unitDim operand Lhs or Rhs should be same,
+// (2) - the defining source memref should be same for nonUnitDim
+// operation,
+// (3) - the nonUnit dim offset difference between the
+// vector.contracts should be 8.
----------------
adam-smnk wrote:
Isn't this configurable through `nonUnitDimValue`? Header mentions also 16?
Could you please clarify and unify the docs?
https://github.com/llvm/llvm-project/pull/174590
More information about the Mlir-commits
mailing list