[Mlir-commits] [mlir] [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (PR #174590)

Adam Siemieniuk llvmlistbot at llvm.org
Wed Feb 11 00:59:27 PST 2026


================
@@ -104,5 +111,305 @@ bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
   return true;
 }
 
+/// Pair of `vector.shuffle` masks: `maskLo` produces the first (low) shuffled
+/// vector and `maskHi` the second (high) one. Both members are non-owning
+/// views, so the referenced mask data must outlive this struct
+/// (getShuffleMasks points them at function-local static arrays).
+struct ShuffleMasks {
+  // Mask for the low shuffled result.
+  llvm::ArrayRef<int64_t> maskLo{};
+  // Mask for the high shuffled result.
+  llvm::ArrayRef<int64_t> maskHi{};
+};
+
+/// Returns the lo/hi `vector.shuffle` masks used to interleave two flat
+/// accumulator vectors of `nonUnitDimAcc` elements each.
+///
+/// In each mask, indices below `nonUnitDimAcc` select from the first shuffle
+/// operand and the remaining indices select from the second one.
+///  - 16 (targets AVX512): the low mask takes elements 0-3 of each operand,
+///    then elements 4-7 of each; the high mask does the same for 8-11 and
+///    12-15.
+///  - 8 (targets AVX2): element-wise interleave; the low mask covers
+///    elements 0-3 of both operands, the high mask elements 4-7.
+///
+/// Only 8 and 16 are supported (asserted). In builds without assertions any
+/// other value falls back to the 8-element masks.
+inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
+  assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
+         "Unsupported nonUnitDimAcc value");
+
+  if (nonUnitDimAcc == 16) {
+    // Two <16xf32> vectors, interleaved in groups of four (AVX512).
+    static constexpr int64_t kAvx512Lo[] = {0, 1, 2, 3, 16, 17, 18, 19,
+                                            4, 5, 6, 7, 20, 21, 22, 23};
+    static constexpr int64_t kAvx512Hi[] = {8,  9,  10, 11, 24, 25, 26, 27,
+                                            12, 13, 14, 15, 28, 29, 30, 31};
+    return {kAvx512Lo, kAvx512Hi};
+  }
+
+  // Two <8xf32> vectors, interleaved element by element (AVX2).
+  static constexpr int64_t kAvx2Lo[] = {0, 8, 1, 9, 2, 10, 3, 11};
+  static constexpr int64_t kAvx2Hi[] = {4, 12, 5, 13, 6, 14, 7, 15};
+  return {kAvx2Lo, kAvx2Hi};
+}
+
+// Walks backward from `v` towards the operation that produced it, looking
+// for a vector read-like operation (`vector.transfer_read` or `vector.load`):
+//  - a read-like defining op is returned;
+//  - a layout-changing defining op (`vector.shape_cast`, `vector.shuffle`)
+//    stops the walk and yields `nullptr`;
+//  - any other defining op with exactly one operand is returned as-is (the
+//    walk does not continue into its operand);
+//  - a defining op with zero or several operands yields `nullptr`;
+//  - an `scf.for` region iter_arg is followed to the loop's matching init
+//    operand; the induction variable and any other block argument yield
+//    `nullptr`.
+// The result is therefore the read-like op, the nearest single-operand
+// producer, or `nullptr`.
+//
+// NOTE(review): a single-operand producer is returned rather than walked
+// into, so the result is not always read-like. Confirm this is intended.
+// Switching that `return defOp;` to `continue;` would hand callers the read
+// op itself, whose direct users may then no longer be the contraction /
+// scf.for ops that the use-rewriting step looks for.
+Operation *traceToVectorReadLikeParentOperation(Value v) {
+  while (true) {
+    // Case 1: `v` is the result of an operation.
+    if (Operation *defOp = v.getDefiningOp()) {
+      if (isa<vector::TransferReadOp, vector::LoadOp>(defOp))
+        return defOp;
+
+      // Do not look through ops that change the vector layout.
+      if (isa<vector::ShapeCastOp, vector::ShuffleOp>(defOp))
+        return nullptr;
+
+      // Nearest single-operand producer; see NOTE(review) above.
+      if (defOp->getNumOperands() == 1)
+        return defOp;
+
+      return nullptr;
+    }
+
+    // Case 2: `v` is a block argument; only scf.for iter_args are followed.
+    if (auto barg = dyn_cast<BlockArgument>(v)) {
+      auto forOp = dyn_cast<scf::ForOp>(barg.getOwner()->getParentOp());
+      if (!forOp)
+        return nullptr;
+
+      // Block argument 0 is the induction variable, not an iter_arg.
+      unsigned argNum = barg.getArgNumber();
+      if (argNum == 0)
+        return nullptr;
+
+      // Block argument `i + 1` (iter_arg `i`) is tied to init operand `i`.
+      v = forOp.getInitArgs()[argNum - 1];
+      continue;
+    }
+
+    return nullptr;
+  }
+}
+
+// Recursively follows the uses of `v` to find a downstream vector write-like
+// operation (`vector.transfer_write` or `vector.store`). For each use, in
+// use-list order:
+//  - a write-like user is returned immediately;
+//  - a layout-changing user (`vector.shape_cast`, `vector.shuffle`) stops the
+//    search at this level and yields `nullptr`, even if later uses of `v`
+//    would reach a write;
+//  - an `scf.yield` use is followed to the same-index result of the yield's
+//    parent op (skipped if the parent has no such result);
+//  - an `scf.for` init-operand use is followed to the tied loop result (uses
+//    as lower bound, upper bound or step are skipped);
+//  - for any other user, each of its results is searched recursively.
+// Returns the first write-like user found, or `nullptr` if none is found.
+//
+// NOTE(review): a shape_cast/shuffle user aborts only the current level
+// (recursive callers treat `nullptr` as "try the next use"), so the outcome
+// can depend on use-list order. Confirm whether such users should instead
+// just prune their own path.
+Operation *traceToVectorWriteLikeUserOperation(Value v) {
+  for (OpOperand &use : v.getUses()) {
+    Operation *user = use.getOwner();
+
+    // --- TERMINAL OPS ---
+    if (isa<vector::TransferWriteOp, vector::StoreOp>(user))
+      return user;
+
+    if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
+      return nullptr;
+
+    // --- SCF YIELD ---
+    // Yield operand `i` is followed to result `i` of the parent op. Guard the
+    // index so a parent whose results do not mirror the yield operands cannot
+    // trigger an out-of-range getResult().
+    if (auto yield = dyn_cast<scf::YieldOp>(user)) {
+      Operation *parent = yield->getParentOp();
+      unsigned idx = use.getOperandNumber();
+      if (idx < parent->getNumResults()) {
+        if (Operation *res =
+                traceToVectorWriteLikeUserOperation(parent->getResult(idx)))
+          return res;
+      }
+      continue;
+    }
+
+    // --- SCF FOR ---
+    // The operand number also counts the control operands (lb, ub, step), so
+    // they must be subtracted to get the index of the tied loop result.
+    if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+      unsigned idx = use.getOperandNumber();
+      unsigned numControl = forOp.getNumControlOperands();
+      if (idx < numControl)
+        continue;
+      if (Operation *res = traceToVectorWriteLikeUserOperation(
+              forOp.getResult(idx - numControl)))
+        return res;
+      continue;
+    }
+
+    // --- GENERIC CASE ---
+    for (Value res : user->getResults()) {
+      if (Operation *found = traceToVectorWriteLikeUserOperation(res))
+        return found;
+    }
+  }
+
+  return nullptr;
+}
+
+// Redirects uses of `oldVal` to `newVal`, but only for uses owned by a
+// `vector.contract` or `scf.for` op; every other use keeps `oldVal`.
+//
+// TODO: replace all uses with the packed value, along with the contraction
+// and for op.
+// NOTE(review): operands are updated directly rather than through a
+// rewriter, so a pattern rewriter (and its listeners) is not notified that
+// the contraction / scf.for users were modified in place. Inside a rewrite
+// pattern, `rewriter.replaceUsesWithIf(...)` is presumably what the review
+// question is asking for.
+static void rewriteUses(mlir::Value oldVal, mlir::Value newVal) {
+  oldVal.replaceUsesWithIf(newVal, [](mlir::OpOperand &use) {
+    return mlir::isa<mlir::vector::ContractionOp, mlir::scf::ForOp>(
+        use.getOwner());
+  });
+}
+
+// This function shuffles the accumulators of two flat-layout BF16
+// vector.contract operations into a VNNI-packed arrangement. The shuffled
+// values then replace the original accumulators in their contraction (or
+// scf.for) users, enabling post-read layout or packing transformations.
+void shuffleAfterReadLikeOp(mlir::PatternRewriter &rewriter,
+                            mlir::Operation *opA, mlir::Operation *opB,
+                            mlir::vector::ContractionOp contractA,
+                            mlir::vector::ContractionOp contractB,
+                            int64_t nonUnitDimAcc, mlir::VectorType accTy) {
+  mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
+
+  rewriter.setInsertionPointAfter(insertAfter);
+  mlir::Location loc = insertAfter->getLoc();
+
+  auto elemTy = accTy.getElementType();
+  auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
+
+  auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
+                                                 opA->getResult(0));
+  auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
+                                                 opB->getResult(0));
+
+  auto masks = getShuffleMasks(nonUnitDimAcc);
+
+  auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+                                                   castB, masks.maskLo);
+  auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
+                                                   castB, masks.maskHi);
+
+  auto newAccA =
+      mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleLo);
+  auto newAccB =
+      mlir::vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleHi);
+
+  rewriteUses(opA->getResult(0), newAccA.getResult());
+  rewriteUses(opB->getResult(0), newAccB.getResult());
----------------
adam-smnk wrote:

Why can't this be done with the rewriter?

https://github.com/llvm/llvm-project/pull/174590


More information about the Mlir-commits mailing list