[Mlir-commits] [mlir] [mlir][x86] Lower packed type vector.contract to AMX dot-product (online-packing) (PR #188192)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 26 21:48:24 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Arun Thangamani (arun-thmn)

<details>
<summary>Changes</summary>

A transform pass that lowers flat-layout `vector.contract` operations to (a) `amx.tile_mulf` for BF16 or (b) `amx.tile_muli` for Int8 packed types, using online packing.

TODO: A follow-up patch is planned to refactor this pass and retire the `convert-vector-to-amx` pass.

---

Patch is 81.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/188192.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp (+897-170) 
- (modified) mlir/test/Dialect/X86/AMX/vector-contract-to-tiled-dp.mlir (+480-20) 


``````````diff
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
index 85966a85af40e..744c065b4e05e 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToAMXDotProduct.cpp
@@ -70,8 +70,9 @@ getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
   if (!srcBuff)
     return failure();
 
-  if (isNotAcc)
+  if (isNotAcc) {
     indexVals.pop_back();
+  }
 
   SmallVector<Value> indices;
   indices.reserve(indexVals.size());
@@ -189,37 +190,184 @@ static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
 // Creates amx.tile_loads.
 static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
                                        Value operand, Value mat, Type ipType,
-                                       bool rhs, unsigned int offset) {
+                                       bool rhs, unsigned int offset,
+                                       bool isVnni) {
 
   auto srcIndx = getSrcIndxValue(rewriter, loc, operand, false);
   auto [srcBuff, indices] = *srcIndx;
-  indices.pop_back();
+  if (isVnni) {
+    indices.pop_back();
+  }
 
-  if (rhs) {
+  if (rhs && isVnni) {
     auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
     indices[indices.size() - 1] = arith::MulIOp::create(
         rewriter, loc, indices[indices.size() - 1], cOffset);
   }
 
   amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
-  return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
+  auto load = amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
+  return load;
 }
 
-// Creates tiled amx dot-products.
-static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
-                                        SmallVector<vector::ContractionOp> ops,
-                                        Value matA, Value matB, Type ipType,
-                                        Type opType, ValueRange accIterArgs,
-                                        unsigned int offset) {
+static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
+                           Type ipType, unsigned int offset, Value packedBuffer,
+                           Value indxToStoreInBuffer) {
+
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+
+  auto subview = matB.getDefiningOp<mlir::memref::SubViewOp>();
+  SmallVector<Value> subviewOffset(subview.getOffsets().size(), c0);
+
+  Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
+  Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
+  Value offsetIndx =
+      arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
+
+  scf::ForOp::create(
+      rewriter, loc, c0, cBound, cStep, ValueRange{},
+      [&](OpBuilder &nestedBuilder, Location loc, Value iv,
+          ValueRange iterArgs) {
+        subviewOffset[subviewOffset.size() - 2] = iv;
+        auto vec1 = vector::LoadOp::create(
+            rewriter, loc, VectorType::get((16 * offset), ipType), matB,
+            ValueRange(subviewOffset));
+
+        // Increment the iv by 1 or 2 based on the type to load the next 32/64
+        // elements
+        Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
+        subviewOffset[subviewOffset.size() - 2] = incIV;
+        auto vec2 = vector::LoadOp::create(
+            rewriter, loc, VectorType::get((16 * offset), ipType), matB,
+            ValueRange(subviewOffset));
+
+        vector::ShuffleOp shuffle1;
+        vector::ShuffleOp shuffle2;
+
+        if (ipType.isBF16()) {
+
+          shuffle1 = vector::ShuffleOp::create(
+              rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
+              vec2,
+              ArrayRef<int64_t>{0,  32, 1,  33, 2,  34, 3,  35, 8,  40, 9,
+                                41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50,
+                                19, 51, 24, 56, 25, 57, 26, 58, 27, 59});
+
+          shuffle2 = vector::ShuffleOp::create(
+              rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
+              vec2,
+              ArrayRef<int64_t>{4,  36, 5,  37, 6,  38, 7,  39, 12, 44, 13,
+                                45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54,
+                                23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
+        }
+
+        if (ipType.isSignlessInteger(8)) {
+
+          shuffle1 = vector::ShuffleOp::create(
+              rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
+              vec2,
+              ArrayRef<int64_t>{
+                  0,   32,  64, 96,  1,   33,  65,  97,  2,   34,  66,  98, 3,
+                  35,  67,  99, 4,   36,  68,  100, 5,   37,  69,  101, 6,  38,
+                  70,  102, 7,  39,  71,  103, 8,   40,  72,  104, 9,   41, 73,
+                  105, 10,  42, 74,  106, 11,  43,  75,  107, 12,  44,  76, 108,
+                  13,  45,  77, 109, 14,  46,  78,  110, 15,  47,  79,  111});
+
+          shuffle2 = vector::ShuffleOp::create(
+              rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
+              vec2,
+              ArrayRef<int64_t>{
+                  16, 48,  80, 112, 17, 49,  81, 113, 18, 50,  82, 114, 19, 51,
+                  83, 115, 20, 52,  84, 116, 21, 53,  85, 117, 22, 54,  86, 118,
+                  23, 55,  87, 119, 24, 56,  88, 120, 25, 57,  89, 121, 26, 58,
+                  90, 122, 27, 59,  91, 123, 28, 60,  92, 124, 29, 61,  93, 125,
+                  30, 62,  94, 126, 31, 63,  95, 127});
+        }
+
+        // iv to store the shuffled elements
+        Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
+        Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
+
+        vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
+                                ValueRange{indxToStoreInBuffer, ivShuff1, c0});
+        vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
+                                ValueRange{indxToStoreInBuffer, ivShuff2, c0});
+
+        scf::YieldOp::create(nestedBuilder, loc);
+      });
+}
+
+static llvm::DenseMap<Operation *, amx::TileLoadOp>
+packInputs(OpBuilder &rewriter, Location loc,
+           SmallVector<vector::ContractionOp> ops, Value matB, Type ipType,
+           unsigned int offset, Value packedBuffer, bool pack,
+           Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
+
+  llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
+
+  for (size_t j = 0; j < ops.size(); j++) {
+    for (size_t i = 0; i < ops.size(); i++) {
+
+      if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16)) {
+
+        Operation *readOpRhs = ops[j].getRhs().getDefiningOp();
+        auto itRhs = readsToTileLoads.find(readOpRhs);
+        if (itRhs != readsToTileLoads.end()) {
+          continue;
+        }
 
-  auto subviewCollapseLhs = collapseInnerDims(rewriter, loc, matA);
-  auto subviewCollapseRhs = collapseInnerDims(rewriter, loc, matB);
+        if (pack) {
+          performShuffle(rewriter, loc, matB, ipType, offset, packedBuffer,
+                         indxToStoreInBuffer);
+        }
+
+        amx::TileType tileType =
+            amx::TileType::get({16, (16 * offset)}, ipType);
+        auto loadRow1 =
+            amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
+                                    ValueRange{indxToLoadFromMatB, c0, c0});
+
+        auto loadRow2 =
+            amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
+                                    ValueRange{indxToLoadFromMatB, c16, c0});
+
+        readsToTileLoads.try_emplace(readOpRhs, loadRow1);
+        readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), loadRow2);
+      }
+    }
+  }
+
+  return readsToTileLoads;
+}
+
+// Creates tiled amx dot-products.
+static SmallVector<Value>
+createTiledDp(OpBuilder &rewriter, Location loc,
+              SmallVector<vector::ContractionOp> ops, Value matA, Value matB,
+              Type ipType, Type opType, ValueRange accIterArgs,
+              unsigned int offset, bool isVnni, Value packedBuffer, bool pack,
+              Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
+
+  if (isVnni) {
+    matA = collapseInnerDims(rewriter, loc, matA);
+    matB = collapseInnerDims(rewriter, loc, matB);
+  }
 
   SmallVector<Value> accumulators;
   // Stores the amx.tile_load operation vs it's equivalent vector tranfer_read
   // or load operations.
   llvm::DenseMap<Operation *, amx::TileLoadOp> readsToTileLoads;
 
+  // function call to online pack the input  B matrix
+  if (!isVnni) {
+    readsToTileLoads =
+        packInputs(rewriter, loc, ops, matB, ipType, offset, packedBuffer, pack,
+                   indxToStoreInBuffer, indxToLoadFromMatB);
+  }
+
   // Iterate over the contraction operations and compute the tiled dot-product.
   for (size_t i = 0; i < ops.size(); i++) {
 
@@ -229,8 +377,8 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
     if (itLhs != readsToTileLoads.end()) {
       tilesLhs = itLhs->second;
     } else {
-      tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(),
-                                 subviewCollapseLhs, ipType, false, offset);
+      tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), matA, ipType,
+                                 false, offset, isVnni);
       readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
     }
 
@@ -240,8 +388,8 @@ static SmallVector<Value> createTiledDp(OpBuilder &rewriter, Location loc,
     if (itRhs != readsToTileLoads.end()) {
       tilesRhs = itRhs->second;
     } else {
-      tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(),
-                                 subviewCollapseRhs, ipType, true, offset);
+      tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), matB, ipType,
+                                 true, offset, isVnni);
       readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
     }
 
@@ -276,10 +424,186 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
   return loopItrArgs;
 }
 
+static Value bufferIndxToStore(OpBuilder &rewriter, Location loc,
+                               Value ivInnerLoop, Value ivOuterLoop,
+                               bool isInnerLoopUBHasOddQuot,
+                               bool isInnerLoopUBLarger, bool pack,
+                               unsigned int blockingFactor) {
+
+  Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
+  Value packOffset =
+      arith::ConstantIndexOp::create(rewriter, loc, (16 * blockingFactor));
+
+  Value quotientInnerLoop =
+      arith::DivUIOp::create(rewriter, loc, ivInnerLoop, packOffset);
+  Value remInnerLoop = arith::RemUIOp::create(
+      rewriter, loc, rewriter.getIndexType(), quotientInnerLoop, c2);
+
+  if (!isInnerLoopUBLarger && !pack) {
+    remInnerLoop = arith::RemUIOp::create(
+        rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
+  }
+
+  // if K quotient is odd. Then, BR loop iv is taken
+  // into consideration
+  if (isInnerLoopUBHasOddQuot) {
+    auto remOuterLoop = arith::RemUIOp::create(
+        rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
+    auto remAdd = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
+                                        remInnerLoop, remOuterLoop);
+    remInnerLoop = arith::RemUIOp::create(rewriter, loc,
+                                          rewriter.getIndexType(), remAdd, c2);
+  }
+  return remInnerLoop;
+}
+
+static scf::ForOp
+createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
+            Value upperBound, Value step, SmallVector<Value> loopItrArgs,
+            Type ipType, Type opType, unsigned int blockingFactor, bool isVnni,
+            Operation *vectorOpLhs, Operation *vectorOpRhs,
+            vector::ContractionOp contractOp, scf::ForOp outerLoop,
+            scf::ForOp innerLoop, SmallVector<vector::ContractionOp> ops,
+            Value ivOuterLoop, Value packedBuffer, bool pack,
+            arith::ConstantIndexOp innerLoopIndex, bool isInnerLoopUBLarger,
+            bool isInnerLoopUBHasOddQuot) {
+
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
+  Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
+
+  auto newLoop = scf::ForOp::create(
+      rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
+      [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
+          Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
+        IRMapping mapping;
+        if (outerLoop) {
+          mapping.map(vectorOpLhs->getOperand(
+                          getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
+                      ivOuterLoop);
+        }
+        mapping.map(vectorOpLhs->getOperand(
+                        getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
+                    ivNewInnerLoop);
+        auto lhsClone = rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
+
+        Value indxToStoreInBuffer = c0;
+        Value indxToLoadFromBuffer = c0;
+
+        if (!isVnni) {
+          if (outerLoop) {
+            if (innerLoopIndex.value() == 0) {
+              if (pack) {
+                ivNewInnerLoop = c0;
+                ivOuterLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
+                                                    c1, ivOuterLoop);
+
+                if (!isInnerLoopUBLarger || isInnerLoopUBHasOddQuot) {
+                  indxToStoreInBuffer = arith::RemUIOp::create(
+                      rewriter, locNewInnerLoop, rewriter.getIndexType(),
+                      ivOuterLoop, c2);
+                }
+
+                Value indxToLoadFromMatB = arith::AddIOp::create(
+                    rewriter, loc, indxToStoreInBuffer, c1);
+                indxToLoadFromBuffer = arith::RemUIOp::create(
+                    rewriter, loc, rewriter.getIndexType(), indxToLoadFromMatB,
+                    c2);
+              }
+
+            } else {
+              Value nLoadIndx = arith::ConstantIndexOp::create(
+                  rewriter, locNewInnerLoop, (16 * blockingFactor));
+              ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
+                                                     nLoadIndx, ivNewInnerLoop);
+              indxToStoreInBuffer =
+                  bufferIndxToStore(rewriter, loc, ivNewInnerLoop, ivOuterLoop,
+                                    isInnerLoopUBHasOddQuot,
+                                    isInnerLoopUBLarger, pack, blockingFactor);
+              Value indxToLoadFromMatB =
+                  arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
+              indxToLoadFromBuffer =
+                  arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
+                                         indxToLoadFromMatB, c2);
+            }
+          } else {
+            if (pack) {
+              Value nLoadIndx = arith::ConstantIndexOp::create(
+                  rewriter, locNewInnerLoop, (16 * blockingFactor));
+              ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
+                                                     nLoadIndx, ivNewInnerLoop);
+              Value quotient_K = arith::DivUIOp::create(
+                  rewriter, loc, ivNewInnerLoop, nLoadIndx);
+              indxToStoreInBuffer = arith::RemUIOp::create(
+                  rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
+
+              Value indxToLoadFromMatB =
+                  arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
+              indxToLoadFromBuffer =
+                  arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
+                                         indxToLoadFromMatB, c2);
+            }
+          }
+        }
+
+        IRMapping rhsMapping;
+        if (outerLoop) {
+          rhsMapping.map(
+              vectorOpRhs->getOperand(
+                  getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
+              ivOuterLoop);
+        }
+        rhsMapping.map(
+            vectorOpRhs->getOperand(
+                getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
+            ivNewInnerLoop);
+        auto rhsClone = rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
+
+        Value matB = rhsClone->getResult(0);
+
+        if (!isVnni) {
+          if (outerLoop) {
+            if (!pack) {
+              Value nLoadIndx = arith::ConstantIndexOp::create(
+                  rewriter, locNewInnerLoop, (16 * blockingFactor));
+              matB = Value();
+              indxToLoadFromBuffer = c0;
+              indxToLoadFromBuffer =
+                  bufferIndxToStore(rewriter, loc, nLoadIndx, ivOuterLoop,
+                                    isInnerLoopUBHasOddQuot,
+                                    isInnerLoopUBLarger, pack, blockingFactor);
+            }
+          } else {
+            if (!pack) {
+              Value nLoadIndx = arith::ConstantIndexOp::create(
+                  rewriter, locNewInnerLoop, (16 * blockingFactor));
+              matB = Value();
+              Value quotient_K = arith::DivUIOp::create(
+                  rewriter, loc, ivNewInnerLoop, nLoadIndx);
+              indxToLoadFromBuffer = arith::RemUIOp::create(
+                  rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
+            }
+          }
+        }
+
+        // compute tiled dot-product
+        SmallVector<Value> accumulators = createTiledDp(
+            rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
+            ipType, opType, iterArgsNewInnerLoop, blockingFactor, isVnni,
+            packedBuffer, pack, indxToStoreInBuffer, indxToLoadFromBuffer);
+
+        scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
+                             accumulators);
+      });
+
+  return newLoop;
+}
+
 // Implements tiled dot-product operation for a vector.contract operation or a
 // sequence of vector.contracts inside the reduction loops.
 //
-// For example - for F32 type:
+// For example:
+// Case 1: register blocked vector.contract with prepacked input
 // ```
 //   vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
 //   vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
@@ -293,6 +617,52 @@ static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
 //   amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
 //   amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
 // ```
+//
+//
+// Case 2: register blocked vector.contract with online packing
+//
+// Output IR with online packing (with s/w pipeline advantage):
+// s/w pipeline: load, pack to VNNI, and store the B sub matrix
+// of the 0th batch-reduce and K iteration.
+// scf.for (0 to 31) {
+// 	- load 0th and 1st  vector<32xbf16>, pack into VNNI, store the
+// 	first shuffle in 0th and 2nd shuffle in 16th index of the
+// 	buffer.
+// }
+// scf.for (0 to br-2) { batch-reduce loop
+//   scf.for (0 to k-2) { K loop
+// 	- load A matrix
+//      - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
+//      matrix for the next K loop iteration (c) load VNNI pack B matrix of K
+//      iteration from the buffer (d) compute the tiled dot-product
+//   }
+//   Last iteration of the K Loop (k-1) {
+//      - load A matrix
+//      - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
+//      matrix for the next batch-reduce + K loop iteration (c) load VNNI pack B
+//      matrix of K iteration from the buffer (d) compute the tiled dot-product
+//   }
+// }
+// Last iteration of the batch-reduce loop (br-1) {
+//   scf.for (0 to k-2) { K loop
+//      - load A matrix
+//      - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
+//      matrix for the next K loop iteration (c) load VNNI pack B matrix of K
+//      iteration from the buffer (d) compute the tiled dot-product
+//   }...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/188192


More information about the Mlir-commits mailing list