[Mlir-commits] [mlir] [mlir][x86vector] Shuffle BF16 vector.contract output for Flat layout. (PR #174590)

Arun Thangamani llvmlistbot at llvm.org
Thu Feb 12 02:14:29 PST 2026


================
@@ -94,127 +219,188 @@ struct VectorContractToPackedTypeDotProduct
                                          "Only F32 for BF16 or Int32 for Int8 "
                                          "accumulation type is supported.");
 
-    ArrayRef<int64_t> accShape = accTy.getShape();
-    llvm::SmallVector<int64_t> nonUnitDimAcc;
-    llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
-                  [](int64_t dim) { return dim != 1; });
-    if (nonUnitDimAcc.size() != 1)
-      return rewriter.notifyMatchFailure(
-          contractOp, "A or B should be a non-unit dim in acc.");
+    Value unitDimOperand =
+        rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
+    Value nonUnitDimOperand =
+        rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
 
-    // Non-unit dimensions should match the vector length of BF16 or Int8
-    // dot-product.
-    unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
-                                                        : nonUnitDimRhs.front();
-    if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
-        nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
-      return rewriter.notifyMatchFailure(
-          contractOp, "BF16 dot-product operation expects non-unit (LHR or "
-                      "RHS) dim and acc dim of size 4/8/16.");
+    // If the A or B matrix vector of the contact operation is not packed, then
+    // find it's pair contract operation and pack (shuffle) them to VNNI packed.
+    if (!isVnni) {
+      vector::ContractionOp pairContractOp;
+      Operation *nextOp = contractOp;
+      while ((nextOp = nextOp->getNextNode())) {
+        auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
 
-    if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
-        nonUnitDim != 8 && nonUnitDim != 16 &&
-        nonUnitDimAcc.front() == nonUnitDim)
-      return rewriter.notifyMatchFailure(
-          contractOp, "Int8 dot-product operation expects non-unit (LHR or "
-                      "RHS) dim and acc dim of size 4/8/16.");
+        if (!contOp)
+          continue;
+
+        if (validatePairVectorContract(contractOp, contOp,
+                                       rhsHasMultipleNonUnitDims,
+                                       nonUnitDimValue)) {
+          pairContractOp = contOp;
+          break;
+        }
+      }
+
+      // If the accumulators are shuffled we get nullptr else the
+      // transfer_read or load operations.
+      Operation *accRead =
+          traceToVectorReadLikeParentOperation(contractOp.getAcc());
+
+      if (!pairContractOp &&
+          (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
+        return rewriter.notifyMatchFailure(contractOp,
+                                           "Could not find a contract pair");
+
+      if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
+        Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
+                                                  ? pairContractOp.getRhs()
+                                                  : pairContractOp.getLhs();
+
+        // Get the non-packed A or B matrix's vector<32xbf16> elements.
+        Operation *nonUnitDimReadOp =
+            traceToVectorReadLikeParentOperation(nonUnitDimOperand);
+        Operation *nonUnitDimReadOpPairContract =
+            traceToVectorReadLikeParentOperation(nonUnitDimOperandPairContract);
+
+        if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
+          return rewriter.notifyMatchFailure(
+              contractOp, "Could not find a valid contract pair");
+
+        if (contractOp->getBlock() ==
+                nonUnitDimReadOpPairContract->getBlock() &&
+            contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract))
+          return rewriter.notifyMatchFailure(
+              contractOp,
+              "The load/read operation of pair contract operation is "
+              "after the contractOp");
+
+        VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
+                                      ? contractOp.getRhsType()
+                                      : contractOp.getLhsType();
+
+        packNonUnitDimOperandToVNNI(
+            rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
+            contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
+            nonUnitDimTy);
+
+        nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
+                                                      : contractOp.getLhs();
+      }
+
+      // Validate and shuffle the accumulator
+      if (accRead) {
+        // Trace back to the load or transfer_read operations of the contract
+        // accumulators.
+        Operation *accReadOp0 =
+            traceToVectorReadLikeParentOperation(contractOp.getAcc());
+        Operation *accReadOp1 =
+            traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
 
+        // Iterate dowm to find the users of contact operations until it is
----------------
arun-thmn wrote:

Fixed.

https://github.com/llvm/llvm-project/pull/174590


More information about the Mlir-commits mailing list