[Mlir-commits] [mlir] [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (PR #174590)
Adam Siemieniuk
llvmlistbot at llvm.org
Wed Feb 11 00:59:27 PST 2026
================
@@ -104,5 +111,305 @@ bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
return true;
}
+struct ShuffleMasks {
+ llvm::ArrayRef<int64_t> maskLo;
+ llvm::ArrayRef<int64_t> maskHi;
+};
+
+inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
+ // We only support these two layouts for now.
+ assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
+ "Unsupported nonUnitDimAcc value");
+ // Do interleaving between two <8xf32> targeting AVX2.
+ static constexpr int64_t maskLo8[] = {0, 8, 1, 9, 2, 10, 3, 11};
+ static constexpr int64_t maskHi8[] = {4, 12, 5, 13, 6, 14, 7, 15};
+
+ // Shuffle two <16xf32> as below targeting AVX512.
+ static constexpr int64_t maskLo16[] = {0, 1, 2, 3, 16, 17, 18, 19,
+ 4, 5, 6, 7, 20, 21, 22, 23};
+ static constexpr int64_t maskHi16[] = {8, 9, 10, 11, 24, 25, 26, 27,
+ 12, 13, 14, 15, 28, 29, 30, 31};
+
+ if (nonUnitDimAcc == 16)
+ return {maskLo16, maskHi16};
+
+ return {maskLo8, maskHi8};
+}
+
+// This function walks backward from a value to locate its originating
+// vector read-like operation (`vector.transfer_read` or `vector.load`).
+// It follows simple forwarding through unary ops and across `scf.for`
+// loop iter-arguments, while stopping if layout-transforming ops such
+// as `shape_cast` or `shuffle` are encountered. The traversal returns
+// the read-like defining operation or `nullptr` if no valid source
+// is found.
+Operation *traceToVectorReadLikeParentOperation(Value v) {
+ while (true) {
+ // Case 1: Value defined by an operation
+ if (Operation *defOp = v.getDefiningOp()) {
+ if (isa<vector::TransferReadOp, vector::LoadOp>(defOp)) {
+ return defOp;
+ }
+
+ if (isa<vector::ShapeCastOp, vector::ShuffleOp>(defOp)) {
+ return nullptr;
+ }
+
+ if (defOp->getNumOperands() == 1) {
+ v = defOp->getOperand(0);
+ return defOp;
+ }
+
+ return nullptr;
+ }
+
+ // Case 2: BlockArgument (scf.for iter_arg)
+ if (auto barg = dyn_cast<BlockArgument>(v)) {
+ auto *parentOp = barg.getOwner()->getParentOp();
+
+ if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
+ unsigned argNum = barg.getArgNumber();
+
+ // arg0 = induction variable (not an iter_arg)
+ if (argNum == 0)
+ return nullptr;
+
+ unsigned iterIdx = argNum - 1;
+ v = forOp.getInitArgs()[iterIdx];
+ continue;
+ }
+
+ return nullptr;
+ }
+
+ return nullptr;
+ }
+}
+
+// This function recursively traces a value through its uses to find
+// a downstream vector write-like operation (`vector.transfer_write`
+// or `vector.store`). It transparently follows values across `scf.for`
+// and `scf.yield` boundaries while stopping if layout-altering ops such
+// as `shape_cast` or `shuffle` are encountered. The traversal returns
+// the first matching write-like user or `nullptr` if none is found.
+Operation *traceToVectorWriteLikeUserOperation(Value v) {
+ for (OpOperand &use : v.getUses()) {
+ Operation *user = use.getOwner();
+
+ // --- TERMINAL OPS ---
+ if (isa<vector::TransferWriteOp>(user) || isa<vector::StoreOp>(user)) {
+ return user;
+ }
+
+ if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user)) {
+ return nullptr;
+ }
+
+ // --- SCF YIELD ---
+ if (auto yield = dyn_cast<scf::YieldOp>(user)) {
+ Operation *parent = yield->getParentOp();
+ unsigned idx = use.getOperandNumber();
+ if (auto *res =
+ traceToVectorWriteLikeUserOperation(parent->getResult(idx)))
+ return res;
+ continue;
+ }
+
+ // --- SCF FOR ---
+ if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+ unsigned idx = use.getOperandNumber();
+ if (auto *res = traceToVectorWriteLikeUserOperation(forOp.getResult(idx)))
+ return res;
+ continue;
+ }
+
+ // --- GENERIC CASE ---
+ for (Value res : user->getResults()) {
+ if (auto *found = traceToVectorWriteLikeUserOperation(res))
+ return found;
+ }
+ }
+
+ return nullptr;
+}
+
+// TODO: replace all use with the packed value along with contration
+// and for op.
+static void rewriteUses(mlir::Value oldVal, mlir::Value newVal) {
+ for (mlir::OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) {
+ mlir::Operation *user = use.getOwner();
+
+ if (mlir::isa<mlir::vector::ContractionOp>(user) ||
+ mlir::isa<mlir::scf::ForOp>(user)) {
+ use.set(newVal);
+ }
+ }
+}
+
+// This function packs the accumulator of two flat BF16 vector.contract
+// operations into VNNI packed and are then replaced in their respective
+// contraction ops, enabling post-read layout or packing transformations.
+void shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
+ mlir::Operation *opA, mlir::Operation *opB,
+ mlir::vector::ContractionOp contractA,
+ mlir::vector::ContractionOp contractB,
+ int64_t nonUnitDimAcc, mlir::VectorType accTy) {
+ mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
+
+ rewriter.setInsertionPointAfter(insertAfter);
+ mlir::Location loc = insertAfter->getLoc();
+
+ auto elemTy = accTy.getElementType();
+ auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
+
+ auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
+ opA->getResult(0));
+ auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
+ opB->getResult(0));
+
+ auto masks = getShuffleMasks(nonUnitDimAcc);
+
+ auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+ castB, masks.maskLo);
+ auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+ castB, masks.maskHi);
+
+ auto newAccA =
+ mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleLo);
+ auto newAccB =
+ mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleHi);
+
+ rewriteUses(opA->getResult(0), newAccA.getResult());
+ rewriteUses(opB->getResult(0), newAccB.getResult());
----------------
adam-smnk wrote:
Why can't this be done with the rewriter?
https://github.com/llvm/llvm-project/pull/174590
More information about the Mlir-commits
mailing list