[Mlir-commits] [mlir] [mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (PR #170267)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 8 05:19:20 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Arun Thangamani (arun-thmn)

<details>
<summary>Changes</summary>

A `transform` op (an `apply_patterns` pattern set) to lower `BF16` type `vector.contract` to `vector.fma` using `AVX2` BF16 packed operations:

- `vbcstnebf162ps` - Broadcasts a BF16/F16 element into packed F32.
- `vcvtneebf162ps` - Converts the even-indexed packed BF16/F16 elements into packed F32.
- `vcvtneobf162ps` - Converts the odd-indexed packed BF16/F16 elements into packed F32.


---

Patch is 38.84 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170267.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td (+11) 
- (modified) mlir/include/mlir/Dialect/X86Vector/Transforms.h (+5) 
- (added) mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h (+43) 
- (modified) mlir/lib/Dialect/X86Vector/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp (+5) 
- (modified) mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt (+3-1) 
- (added) mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp (+274) 
- (modified) mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp (+1-86) 
- (added) mlir/lib/Dialect/X86Vector/Utils/CMakeLists.txt (+13) 
- (added) mlir/lib/Dialect/X86Vector/Utils/X86VectorUtils.cpp (+111) 
- (added) mlir/test/Dialect/X86Vector/vector-contract-bf16-to-fma.mlir (+301) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 3c5294ff14fc7..9c3ed1c8092a1 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -38,6 +38,17 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyVectorContractBF16ToFMAPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.vector_contract_bf16_to_fma",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns to lower a BF16 `vector.contract` operation in VNNI
+    layout to `vector.fma` operations with F32 accumulation, emulating the
+    BF16 dot product with BF16 packed broadcast/convert operations.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 
 #endif // X86VECTOR_TRANSFORM_OPS
 
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index fc46dff63c2b7..c4960ae28cb4f 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -91,6 +91,11 @@ void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
 void populateVectorContractToPackedTypeDotProductPatterns(
     RewritePatternSet &patterns);
 
+// A set of patterns for lowering BF16 vector contraction operations in VNNI
+// layout (pairs of BF16 elements packed into 32 bits) to vector fused
+// multiply-add (FMA) operations accumulating in F32, emulated with BF16
+// packed broadcast / convert operations on odd- and even-indexed elements.
+void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
new file mode 100644
index 0000000000000..8a76009ddb907
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/Utils/X86VectorUtils.h
@@ -0,0 +1,43 @@
+//===- X86VectorUtils.h - X86Vector Utilities -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
+#define MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
+
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include <cstdint>
+#include <optional>
+#include <string>
+
+namespace mlir {
+class Type;
+class ShapedType;
+class OpOperand;
+class AffineDimExpr;
+class AffineMap;
+class VectorType;
+class Operation;
+
+namespace x86vector {
+// Operand ranks associated with VNNI-layout operands of several op kinds.
+// NOTE(review): the meaning of each enumerator is inferred from its name only
+// (transpose, GEMM, batch-reduce GEMM inputs / outputs); confirm against the
+// users of this enum. TRANSPOSE, GEMM and BRGEMM_OUTS all have the value 3,
+// so they compare equal to each other.
+enum class VnniOperandRank {
+  TRANSPOSE = 3,
+  GEMM = 3,
+  BRGEMM_INS = 4,
+  BRGEMM_OUTS = 3
+};
+
+// Return true if the operation is in VNNI layout.
+// Optionally, the check can be constrained to a specific VNNI blocking factor.
+//
+// NOTE(review): presumably `indexingMaps` is ordered (A, B, C) and operands 0
+// and 1 of `op` are the A and B matrices, as in the former static helper of
+// VectorContractToPackedTypeDotProduct.cpp. That helper required A and B of
+// rank >= 3, at least two reduction dimensions, and a shared innermost
+// reduction (VNNI) dimension in A and B. Confirm against X86VectorUtils.cpp.
+bool isInVnniLayout(Operation *op, llvm::ArrayRef<AffineMap> indexingMaps,
+                    std::optional<unsigned> blockingFactor = std::nullopt);
+
+} // namespace x86vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_UTILS_X86VECTORUTILS_H_
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index cb1e9d01821a2..329a6c3e80254 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(TransformOps)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
index 95db208207672..172f159b43f80 100644
--- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -32,6 +32,11 @@ void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
   x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
 }
 
+void mlir::transform::ApplyVectorContractBF16ToFMAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  // Delegate to the x86vector entry point that registers the BF16 -> FMA
+  // contraction lowering.
+  ::mlir::x86vector::populateVectorContractBF16ToFMAPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index 2cab50fb591c4..9eb94691753cf 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -3,14 +3,16 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   LegalizeForLLVMExport.cpp
   VectorContractToFMA.cpp
   VectorContractToPackedTypeDotProduct.cpp
+  VectorContractBF16ToFMA.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
-  MLIRX86VectorDialect
   MLIRIR
   MLIRLinalgDialect
   MLIRLLVMCommonConversion
   MLIRLLVMDialect
   MLIRVectorDialect
   MLIRVectorUtils
+  MLIRX86VectorDialect
+  MLIRX86VectorUtils
   )
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
new file mode 100644
index 0000000000000..43c3a38b277b7
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractBF16ToFMA.cpp
@@ -0,0 +1,274 @@
+//===- VectorContractBF16ToFMA.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
+
+#include <optional>
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Traces `prodOp` back to the vector.load / vector.transfer_read producing it
+// and creates memref.subview(s) of that read's memref source, so that the BF16
+// packed operations can broadcast or convert BF16 elements straight from
+// memory into packed F32 elements.
+//
+//   - mnDim:     extent of the M (LHS) / N (RHS) dimension of the subview;
+//                1 for the operand whose M/N dimension is unit.
+//   - vnniDim:   extent of the innermost (VNNI) dimension of the subview.
+//   - mnDimIndx: position of the M/N dimension counted from the innermost
+//                dimension (1 = innermost, 2 = second innermost, ...).
+//
+// For mnDim != 1 a single subview starting at the read's indices is returned.
+// For mnDim == 1 two single-element subviews are returned: the odd-indexed
+// VNNI element (offset 1) first, then the even-indexed one (offset 0).
+//
+// Returns failure, without creating any IR, when the producer is not a
+// vector.load or an unmasked, minor-identity vector.transfer_read of a memref,
+// when the read has fewer than `mnDimIndx` indices, or when (for mnDim == 1)
+// the innermost index is not the constant 0. The odd/even subviews use
+// absolute offsets 1/0, which is only correct if the read starts at the
+// beginning of the VNNI pair.
+//
+// For example:
+// ```
+//   vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
+//   vector.load %arg1[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
+// ```
+// to
+// ```
+//   memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+//   memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
+//   memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
+// ```
+static FailureOr<llvm::SmallVector<mlir::memref::SubViewOp>>
+getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter,
+                          mlir::Value prodOp, int64_t mnDim, int64_t vnniDim,
+                          int64_t mnDimIndx) {
+  Value srcOperation;
+  SmallVector<OpFoldResult> indexVals;
+
+  if (auto transferRead =
+          prodOp.getDefiningOp<mlir::vector::TransferReadOp>()) {
+    // A masked or permuted read is not a plain strided load; accessing memory
+    // directly through a subview would silently drop those semantics.
+    if (transferRead.getMask() ||
+        !transferRead.getPermutationMap().isMinorIdentity())
+      return failure();
+    srcOperation = transferRead.getOperand(0);
+    indexVals = SmallVector<OpFoldResult>(transferRead.getIndices().begin(),
+                                          transferRead.getIndices().end());
+  } else if (auto load = prodOp.getDefiningOp<mlir::vector::LoadOp>()) {
+    srcOperation = load.getOperand(0);
+    indexVals = SmallVector<OpFoldResult>(load.getIndices().begin(),
+                                          load.getIndices().end());
+  }
+
+  if (!srcOperation || !llvm::isa<mlir::MemRefType>(srcOperation.getType()))
+    return failure();
+
+  // Positions are computed relative to the innermost index; reject reads whose
+  // rank is too small instead of underflowing the unsigned arithmetic.
+  const size_t rank = indexVals.size();
+  if (mnDimIndx < 1 || rank < static_cast<size_t>(mnDimIndx))
+    return failure();
+  const size_t vnniPos = rank - 1;
+  const size_t mnPos = rank - static_cast<size_t>(mnDimIndx);
+
+  if (mnDim == 1) {
+    std::optional<int64_t> vnniOffset =
+        getConstantIntValue(indexVals[vnniPos]);
+    if (!vnniOffset || *vnniOffset != 0)
+      return failure();
+  }
+
+  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+  SmallVector<OpFoldResult> sizes(rank, rewriter.getIndexAttr(1));
+  sizes[vnniPos] = rewriter.getIndexAttr(vnniDim);
+  sizes[mnPos] = rewriter.getIndexAttr(mnDim);
+
+  llvm::SmallVector<mlir::memref::SubViewOp> subviews;
+  if (mnDim != 1) {
+    subviews.push_back(memref::SubViewOp::create(rewriter, loc, srcOperation,
+                                                 indexVals, sizes, strides));
+    return subviews;
+  }
+
+  // Odd-indexed element of the VNNI pair.
+  indexVals[vnniPos] = rewriter.getIndexAttr(1);
+  subviews.push_back(memref::SubViewOp::create(rewriter, loc, srcOperation,
+                                               indexVals, sizes, strides));
+
+  // Even-indexed element of the VNNI pair.
+  indexVals[vnniPos] = rewriter.getIndexAttr(0);
+  sizes[vnniPos] = rewriter.getIndexAttr(1);
+  subviews.push_back(memref::SubViewOp::create(rewriter, loc, srcOperation,
+                                               indexVals, sizes, strides));
+  return subviews;
+}
+
+// Returns true if `operand` is produced by a vector.load or an unmasked,
+// minor-identity vector.transfer_read of a memref with at least `mnDimIndx`
+// indices and, for the unit-dim operand (`isUnitDimOperand`), an innermost
+// (VNNI) index equal to the constant 0. These are the conditions under which
+// getSubviewFromVectorInput can build subviews that faithfully replace the
+// read. Creates no IR, so the pattern can still bail out cleanly.
+static bool isSupportedPackedOperand(Value operand, int64_t mnDimIndx,
+                                     bool isUnitDimOperand) {
+  Value source;
+  SmallVector<Value> indices;
+  if (auto read = operand.getDefiningOp<vector::TransferReadOp>()) {
+    if (read.getMask() || !read.getPermutationMap().isMinorIdentity())
+      return false;
+    source = read.getOperand(0);
+    indices.append(read.getIndices().begin(), read.getIndices().end());
+  } else if (auto load = operand.getDefiningOp<vector::LoadOp>()) {
+    source = load.getOperand(0);
+    indices.append(load.getIndices().begin(), load.getIndices().end());
+  } else {
+    return false;
+  }
+
+  if (!isa<MemRefType>(source.getType()) || mnDimIndx < 1 ||
+      indices.size() < static_cast<size_t>(mnDimIndx))
+    return false;
+
+  if (isUnitDimOperand) {
+    std::optional<int64_t> vnniOffset = getConstantIntValue(indices.back());
+    if (!vnniOffset || *vnniOffset != 0)
+      return false;
+  }
+  return true;
+}
+
+namespace {
+
+// Implements an outer-product style BF16 contraction as BF16 packed
+// broadcast / convert operations on the odd- and even-indexed VNNI elements,
+// followed by two vector.fma operations accumulating in F32. The accumulator
+// is shape-cast to its single non-unit dimension and back.
+//
+// Preconditions: ADD combining kind; BF16 LHS/RHS in VNNI layout with factor
+// 2; F32 accumulator; exactly one of LHS (M) / RHS (N) has a non-unit
+// dimension besides VNNI, of size 4 or 8 and equal to the accumulator's
+// single non-unit dimension; both operands are read from memrefs.
+//
+// For example:
+// ```
+//   %1 = vector.load from memref (%m1) -> vector<1x1x2xbf16>
+//   %2 = vector.load from memref (%m2) -> vector<1x8x2xbf16>
+//   return vector.contract %1, %2, %arg1
+// ```
+// to
+// ```
+//   %1 = x86vector.avx.bcst_to_f32.packed %m1[c1] -> vector<8xf32>
+//   %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
+//   %3 = vector.fma %1, %2, %arg1
+//   %4 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
+//   %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
+//   return vector.fma %4, %5, %3
+// ```
+struct VectorContractBF16ToFMA
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind.");
+
+    // TODO: Move this validation to a common utility folder once all
+    // architecture-specific nanokernel passes are merged into the repo.
+    VectorType lhsTy = contractOp.getLhsType();
+    VectorType rhsTy = contractOp.getRhsType();
+    if (!lhsTy.getElementType().isBF16() || !rhsTy.getElementType().isBF16())
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Only BF16 lowering is supported.");
+
+    if (!isInVnniLayout(contractOp.getOperation(),
+                        contractOp.getIndexingMapsArray(), 2))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Input matrices not in VNNI format.");
+
+    auto getNonUnitDims = [](ArrayRef<int64_t> shape) {
+      SmallVector<int64_t> dims;
+      llvm::copy_if(shape, std::back_inserter(dims),
+                    [](int64_t dim) { return dim != 1; });
+      return dims;
+    };
+    SmallVector<int64_t> nonUnitDimLhs = getNonUnitDims(lhsTy.getShape());
+    SmallVector<int64_t> nonUnitDimRhs = getNonUnitDims(rhsTy.getShape());
+
+    // Besides the VNNI dimension, exactly one of LHS (M) / RHS (N) may carry a
+    // non-unit dimension: only non-unit-dim counts (2, 1) and (1, 2) pass.
+    if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects unit dimensions for either "
+                                         "LHS or RHS shape other than VNNI.");
+    if (nonUnitDimLhs.size() != 2 && nonUnitDimRhs.size() != 2)
+      return rewriter.notifyMatchFailure(
+          contractOp,
+          "Expects one non-unit A/B dimension for either LHS or RHS shape.");
+    const bool rhsHasNonUnitDim = nonUnitDimRhs.size() == 2;
+
+    auto accTy = dyn_cast<VectorType>(contractOp.getAccType());
+    if (!accTy)
+      return rewriter.notifyMatchFailure(contractOp, "Wrong accumulator type.");
+    if (!accTy.getElementType().isF32())
+      return rewriter.notifyMatchFailure(
+          contractOp, "Only F32 accumulation supported for BF16 type.");
+
+    SmallVector<int64_t> nonUnitDimAcc = getNonUnitDims(accTy.getShape());
+    if (nonUnitDimAcc.size() != 1)
+      return rewriter.notifyMatchFailure(
+          contractOp, "A or B should be a non-unit dim in acc.");
+
+    // Only 4- or 8-wide results are supported, and the non-unit M/N dimension
+    // must equal the accumulator's non-unit dimension: both describe the width
+    // of the packed loads and of the FMAs. Either mismatch rejects the op.
+    const int64_t nonUnitDim =
+        rhsHasNonUnitDim ? nonUnitDimRhs.front() : nonUnitDimLhs.front();
+    if ((nonUnitDim != 4 && nonUnitDim != 8) ||
+        nonUnitDimAcc.front() != nonUnitDim)
+      return rewriter.notifyMatchFailure(
+          contractOp, "BF16 packed load operation expects non-unit (LHS or "
+                      "RHS) dim and acc dim of size 4/8.");
+
+    // N is the second-innermost dim of RHS ([..][K/vnni][N][vnni]); M is the
+    // third-innermost dim of LHS ([..][M][K/vnni][vnni]).
+    // NOTE(review): for the LHS (M) case the subview is [M][1][2], which is
+    // strided unless the source's K/vnni extent is 1; confirm the packed
+    // convert op lowering handles or rejects strided memrefs.
+    const int64_t nonUnitMnDimIndx = rhsHasNonUnitDim ? 2 : 3;
+    Value unitOperand =
+        rhsHasNonUnitDim ? contractOp.getLhs() : contractOp.getRhs();
+    Value nonUnitOperand =
+        rhsHasNonUnitDim ? contractOp.getRhs() : contractOp.getLhs();
+
+    // Validate both operands before creating any IR: a pattern must not
+    // modify the IR and then report a match failure.
+    if (!isSupportedPackedOperand(unitOperand, /*mnDimIndx=*/2,
+                                  /*isUnitDimOperand=*/true) ||
+        !isSupportedPackedOperand(nonUnitOperand, nonUnitMnDimIndx,
+                                  /*isUnitDimOperand=*/false))
+      return rewriter.notifyMatchFailure(
+          contractOp, "Expects inputs read from a MemRef by vector.load or an "
+                      "unmasked minor-identity vector.transfer_read.");
+
+    // Lower vector.contract to FMAs with help of BF16 packed ops.
+    Location loc = contractOp.getLoc();
+    auto createSubviews = [&](Value operand, bool isUnitDimOperand) {
+      if (isUnitDimOperand)
+        return getSubviewFromVectorInput(loc, rewriter, operand, /*mnDim=*/1,
+                                         /*vnniDim=*/1, /*mnDimIndx=*/2);
+      return getSubviewFromVectorInput(loc, rewriter, operand, nonUnitDim,
+                                       /*vnniDim=*/2, nonUnitMnDimIndx);
+    };
+    // LHS subviews are always created before RHS subviews. The unit-dim side
+    // yields two subviews (odd, even), the non-unit side one.
+    auto lhsSubviews =
+        createSubviews(contractOp.getLhs(), /*isUnitDimOperand=*/rhsHasNonUnitDim);
+    auto rhsSubviews = createSubviews(contractOp.getRhs(),
+                                      /*isUnitDimOperand=*/!rhsHasNonUnitDim);
+    // Not expected after isSupportedPackedOperand; kept as a defensive check.
+    if (failed(lhsSubviews) || failed(rhsSubviews))
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Failed to create input subviews.");
+    SmallVector<memref::SubViewOp> unitDimSubview =
+        rhsHasNonUnitDim ? *lhsSubviews : *rhsSubviews;
+    SmallVector<memref::SubViewOp> nonUnitDimSubview =
+        rhsHasNonUnitDim ? *rhsSubviews : *lhsSubviews;
+
+    // The accumulator element type is F32 (checked above), so the flattened
+    // accumulator and the packed op results share this type.
+    VectorType dstType =
+        VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
+    auto castAcc = vector::ShapeCastOp::create(rewriter, loc, dstType,
+                                               contractOp.getAcc());
+
+    // Load, broadcast, and do FMA for odd indexed BF16 elements.
+    auto loadBcstOddIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
+        rewriter, loc, dstType, unitDimSubview[0]);
+    auto loadOddIndxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
+        rewriter, loc, dstType, nonUnitDimSubview[0]);
+    auto oddIndxFMA =
+        vector::FMAOp::create(rewriter, loc, loadBcstOddIndxElementToF32,
+                              loadOddIndxElementF32, castAcc);
+
+    // With exactly one use, the even-indexed half is emitted right before that
+    // user instead of right after the odd-indexed half.
+    // NOTE(review): presumably a scheduling choice; confirm intent.
+    if (contractOp->hasOneUse())
+      rewriter.setInsertionPoint(*contractOp->user_begin());
+
+    // Load, broadcast, and do FMA for even indexed BF16 elements.
+    auto loadBcstEvenIndxElementToF32 = x86vector::BcstToPackedF32Op::create(
+        rewriter, loc, dstType, unitDimSubview[1]);
+    auto loadEvenIndxElementF32 =
+        x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
+                                                       nonUnitDimSubview[0]);
+    vector::FMAOp fma =
+        vector::FMAOp::create(rewriter, loc, loadBcstEvenIndxElementToF32,
+                              loadEvenIndxElementF32, oddIndxFMA);
+
+    auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
+    rewriter.replaceOp(contractOp, castFma);
+    return success();
+  }
+};
+
+} // namespace
+
+void x86vector::populateVectorContractBF16ToFMAPatterns(
+    RewritePatternSet &patterns) {
+  // Registers the single BF16 -> FMA contraction lowering pattern.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<VectorContractBF16ToFMA>(context);
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index 1e64811db910b..a00a3e5bdd766 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 
 #include "mlir/IR/BuiltinAttributes.h"
@@ -26,92 +27,6 @@ using namespace mlir::x86vector;
 
 namespace {
 
-static FailureOr<SmallVector<mlir::utils::IteratorType>>
-inferIteratorsFromOutMap(AffineMap map) {
-  if (!map.isProjectedPermutation())
-    return failure();
-  SmallVector<mlir::utils::IteratorType> iterators(
-      map.getNumDims(), mlir::utils::IteratorType::reduction);
-  for (auto expr : map.getResults())
-    if (auto dim = dyn_cast<AffineDimExpr>(expr))
-      iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
-  return iterators;
-}
-
-// Returns true if the operation is in VNNI layout.
-// Optionally, the check can be constrained to a specific VNNI blocking factor.
-static bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
-                           std::optional<unsigned> blockingFactor) {
-  // Narrow down type operations - VNNI only applies to contractions.
-  FailureOr<linalg::ContractionDimensions> dims =
-      linalg::inferContractionDims(indexingMaps);
-  if (failed(dims))
-    return false;
-
-  auto matA = op->getOperand(0);
-  auto matB = op->getOperand(1);
-  auto typeA = dyn_cast<ShapedType>(matA.getType());
-  auto typeB = dyn_cast<ShapedType>(matB.getType());
-  unsigned rankA = typeA.getRank();
-  unsigned rankB = typeB.getRank();
-  // VNNI format requires at least 1 parallel and 2 reduction dimensions.
-  if (rankA < 3 || rankB < 3)
-    return false;
-
-  // At least two reduction dimensions are expected:
-  // one for the VNNI factor and one for the K dimension
-  if (dims->k.size() < 2)
-    return false;
-
-  // Validate affine maps - VNNI computation should be defined by the two
-  // innermost reduction iterators.
-  // The input matrix dimensions layout must match the following:
-  //   - matrix A - [...][K/vnniFactor][vnniFactor]
-  //   - matrix B - [...][K/vnniFactor][N][vnniFactor]
-  auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]);
-  if (failed(maybeIters))
-    return false;
-  SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
-  AffineMap mapA = indexingMaps[0];
-  AffineMap mapB = indexingMaps[1];
-
-  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
-  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
-  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
-      iteratorTypes[vnniDimA.getPosition()] !=
-          mlir::utils::IteratorType::reduction)
-    return false;
-  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
-  auto redDimB...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/170267


More information about the Mlir-commits mailing list