[Mlir-commits] [mlir] [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (PR #174590)
Arun Thangamani
llvmlistbot at llvm.org
Thu Feb 12 02:14:29 PST 2026
================
@@ -94,127 +219,188 @@ struct VectorContractToPackedTypeDotProduct
"Only F32 for BF16 or Int32 for Int8 "
"accumulation type is supported.");
- ArrayRef<int64_t> accShape = accTy.getShape();
- llvm::SmallVector<int64_t> nonUnitDimAcc;
- llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
- [](int64_t dim) { return dim != 1; });
- if (nonUnitDimAcc.size() != 1)
- return rewriter.notifyMatchFailure(
- contractOp, "A or B should be a non-unit dim in acc.");
+ Value unitDimOperand =
+ rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
+ Value nonUnitDimOperand =
+ rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
- // Non-unit dimensions should match the vector length of BF16 or Int8
- // dot-product.
- unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
- : nonUnitDimRhs.front();
- if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
- nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
- return rewriter.notifyMatchFailure(
- contractOp, "BF16 dot-product operation expects non-unit (LHR or "
- "RHS) dim and acc dim of size 4/8/16.");
+ // If the A or B matrix vector of the contract operation is not packed, then
+ // find its pair contract operation and pack (shuffle) them to VNNI packed.
+ if (!isVnni) {
+ vector::ContractionOp pairContractOp;
+ Operation *nextOp = contractOp;
+ while ((nextOp = nextOp->getNextNode())) {
+ auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
- if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
- nonUnitDim != 8 && nonUnitDim != 16 &&
- nonUnitDimAcc.front() == nonUnitDim)
- return rewriter.notifyMatchFailure(
- contractOp, "Int8 dot-product operation expects non-unit (LHR or "
- "RHS) dim and acc dim of size 4/8/16.");
+ if (!contOp)
+ continue;
+
+ if (validatePairVectorContract(contractOp, contOp,
+ rhsHasMultipleNonUnitDims,
+ nonUnitDimValue)) {
+ pairContractOp = contOp;
+ break;
+ }
+ }
+
+ // If the accumulators are shuffled we get nullptr else the
+ // transfer_read or load operations.
+ Operation *accRead =
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
+
+ if (!pairContractOp &&
+ (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Could not find a contract pair");
+
+ if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
+ Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
+ ? pairContractOp.getRhs()
+ : pairContractOp.getLhs();
+
+ // Get the non-packed A or B matrix's vector<32xbf16> elements.
+ Operation *nonUnitDimReadOp =
+ traceToVectorReadLikeParentOperation(nonUnitDimOperand);
+ Operation *nonUnitDimReadOpPairContract =
+ traceToVectorReadLikeParentOperation(nonUnitDimOperandPairContract);
+
+ if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
+ return rewriter.notifyMatchFailure(
+ contractOp, "Could not find a valid contract pair");
+
+ if (contractOp->getBlock() ==
+ nonUnitDimReadOpPairContract->getBlock() &&
+ contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract))
+ return rewriter.notifyMatchFailure(
+ contractOp,
+ "The load/read operation of pair contract operation is "
+ "after the contractOp");
+
+ VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
+ ? contractOp.getRhsType()
+ : contractOp.getLhsType();
+
+ packNonUnitDimOperandToVNNI(
+ rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
+ contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
+ nonUnitDimTy);
+
+ nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
+ : contractOp.getLhs();
+ }
+
+ // Validate and shuffle the accumulator
+ if (accRead) {
+ // Trace back to the load or transfer_read operations of the contract
+ // accumulators.
+ Operation *accReadOp0 =
+ traceToVectorReadLikeParentOperation(contractOp.getAcc());
+ Operation *accReadOp1 =
+ traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
+ // Iterate down to find the users of contract operations until it is
----------------
arun-thmn wrote:
Fixed.
https://github.com/llvm/llvm-project/pull/174590
More information about the Mlir-commits
mailing list