[Mlir-commits] [mlir] [MLIR][AArch64] Lower `vector.contract` to SVE FEAT_BF16 operations (PR #147052)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 4 06:26:54 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-neon

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>

This patch adds lowering of Bfloat16 widening matrix multiply and accumulate `vector.contract`, by parametrising and refactoring the pattern for 8-bit integers.

---

Patch is 65.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147052.diff


12 Files Affected:

- (modified) mlir/include/mlir/Conversion/Passes.td (+4) 
- (modified) mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td (+11-1) 
- (modified) mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h (+2) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+3) 
- (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp (+1-1) 
- (modified) mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp (+6-1) 
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1-1) 
- (added) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp (+607) 
- (removed) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (-366) 
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir (+105) 
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir (+201) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..4f304b39a0528 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1437,6 +1437,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of Arm FEAT_I8MM instructions while lowering "
            "the vector dialect.">,
+    Option<"armBF16", "enable-arm-bf16",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_BF16 instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
index 53784982be6dc..81b3c736b93f3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
@@ -12,7 +12,7 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
 
-def ApplyArmSVELowerContractionPatternsOp
+def ApplyArmSVELowerContractionToI8MMPatternsOp
     : Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_i8mm",
          [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
@@ -23,4 +23,14 @@ def ApplyArmSVELowerContractionPatternsOp
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyArmSVELowerContractionToBFMMLAPatternsOp
+    : Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_bfmmla",
+         [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector contraction-like operations should be lowered to
+    finer-grained vector primitives using the ArmSVE dialect.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
 #endif // ARMSVE_VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 232e2be29e574..de160dbf8ed94 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -23,6 +23,8 @@ void populateArmSVELegalizeForLLVMExportPatterns(
 void populateLowerContractionToSVEI8MMPatternPatterns(
     RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 67c0eca15638a..4d74aabcaa50d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -89,6 +89,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
       if (armSVE)
         populateLowerContractionToSVEI8MMPatternPatterns(patterns);
     }
+    if (armBF16)
+      populateLowerContractionToSVEBFMMLAPatterns(patterns);
+
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index 7180884c77e98..a95fc51d562c2 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -12,7 +12,7 @@
 // TODO: There may be opportunities to unify this with a similar pattern
 // for SVE. See:
 //   https://github.com/llvm/llvm-project/issues/145559
-//   LowerContractionToSVEI8MMPattern.cpp
+//   LowerContractToSVEPatterns.cpp
 //
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
index b2ca4fc1eaa8c..8572c34c8b12b 100644
--- a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
@@ -18,11 +18,16 @@ using namespace mlir;
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//
 
-void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns(
+void transform::ApplyArmSVELowerContractionToI8MMPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
 }
 
+void transform::ApplyArmSVELowerContractionToBFMMLAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  mlir::populateLowerContractionToSVEBFMMLAPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 65f98b44b1b69..c29eaca244b4a 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_dialect_library(MLIRArmSVETransforms
   LegalizeForLLVMExport.cpp
   LegalizeVectorStorage.cpp
-  LowerContractionToSVEI8MMPattern.cpp
+  LowerContractToSVEPatterns.cpp
 
   DEPENDS
   MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
new file mode 100644
index 0000000000000..2987287afe9cd
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -0,0 +1,607 @@
+//===- LowerContractToSVEPatterns.cpp - Contract to I8MM/BF16 ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the SVE FEAT_I8MM and FEAT_BF16 extensions.
+//
+// TODO: There may be opportunities to unify this with a similar pattern
+// for Neon. See:
+//   https://github.com/llvm/llvm-project/issues/145559
+//   LowerContractionToNeonI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+
+namespace {
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return a value only for extensions from `i8` to `i32`, else `std::nullopt`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto vTy = cast<VectorType>(v.getType());
+      if (!vTy.getElementType().isSignlessInteger(8))
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check it's extended from `i8` to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy || !inTy.getElementType().isSignlessInteger(8))
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!outTy || !outTy.getElementType().isSignlessInteger(32))
+    return {};
+
+  return inOp;
+}
+
+/// This class encapsulates the algorithm and parametrisation (in terms of types
+/// and dimensions) of lowering a `vector.contract` to "primitive" matrix
+/// multiplication operations of the SVE dialect (here "primitive" would mean
+/// corresponding to a single target instruction).
+///
+/// Supported are lowering to FEAT_I8MM `smmla`, `ummla`, and `usmmla`, and to
+/// FEAT_BF16 `bfmmla`. All the transformations are very similar to each other;
+/// for concreteness the description below is given for `smmla`.
+///
+/// The lowering triggers for a contraction operation that performs a matrix
+/// multiply of two 8-bit integer matrix tiles with logical dimensions
+/// <Mx8> and <8x[N]> for the left-hand side (LHS) and the right-hand side
+/// (RHS), respectively, added to a 32-bit integer accumulator operand (ACC)
+/// with dimensions <Mx[N]>, yielding a <Mx[N]> 32-bit integer result (OUT).
+///
+/// The operands' shapes are such that the operands can be evenly split into
+/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
+/// instructions. The intent is that M and N are chosen (by higher level
+/// transforms) in such a way as to maximise register usage. The main use case
+/// we envision as of now is MMT4D, thus the RHS operand is expected
+/// pre-transposed.
+///
+/// The matrix multiplication is performed by unrolling the usual tiled matrix
+/// multiplication algorithm using sub-tiles with dimensions <2x8> for the
+/// LHS, <8x[2]> for the RHS, and <2x[2]> for the result and the input
+/// accumulator.
+///
+/// One way to illustrate the operation is as follows:
+///
+/// RHS<8x[N]>:       <8x[2]> <8x[2]> ... <8x[2]>
+///                 +-----------------------------
+/// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///           <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///            ...  |   ...     ...   ...   ...
+///           <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///
+/// The RHS operand is unpacked into N/2 values, each representing a sequence
+/// of VSCALE number of sub-tiles with dimensions <8x2>.
+/// The LHS operand is initially unpacked into M/2 values, each representing a
+/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
+/// VSCALE times. Multiplying thus replicated LHS sub-tile by the corresponding
+/// RHS sub-tile correctly computes an entire result sub-tile.
+/// The 2x2 sub-tiles of the ACC and OUT have rows that are not adjacent
+/// (in memory or when imposing a row-major layout on the 2D vector value).
+/// Reading the ACC is implemented as reading two consecutive rows and
+/// interleaving them by pairs to obtain a vector having length twice the length
+/// of an ACC row. This vector now is a sequence of one-dimensional tiles with
+/// the exact layout needed by the `smmla`/`bfmmla`/etc instructions, which
+/// tiles are extracted one by one. For illustration, if we have a 2x4 ACC tile
+///   a0 a1 b0 b1
+///   a2 a3 b2 b3
+/// we read the two rows as separate values and then interleave by pairs
+/// to obtain
+///   a0 a1 a2 a3 b0 b1 b2 b3
+/// from which we extract `a0 a1 a2 a3` and `b0 b1 b2 b3`.
+///
+/// Writing the OUT tile is done by the reverse of the above procedure,
+/// concatenate two "flattened" sub-tiles into
+///   c0 c1 c2 c3 d0 d1 d2 d3
+/// deinterleave by pairs to obtain as separate values
+///   c0 c1 d0 d1
+///   c2 c3 d2 d3
+/// which are then inserted into the final result.
+///
+/// Multiplication of a signed LHS by an unsigned RHS is performed by
+/// swapping the order of the operands and emitting an `usmmla` (since there
+/// isn't an `summla` instruction). Therefore each ACC sub-tile needs
+/// to be transposed before the addition and the sum, an OUT sub-tile,
+/// needs to be transposed before insertion into the final result.
+/// This is done very elegantly by a modification of the above to
+/// interleave/deinterleave not by pairs, but by individual elements, e.g.
+/// after ordinary interleave we obtain
+///   a0 a2 a1 a3 b0 b2 b1 b3
+/// which is exactly the desired layout of having each individual 2x2 tile
+/// transposed.
+///
+/// All of the above readily applies to FEAT_BF16 `bfmmla` with the
+/// difference that the shapes of the LHS and RHS are <Mx4> and <4x[N]>,
+/// respectively, that is the "K" dimension is fixed to 4, instead of 8 (like
+/// for the integer case).
+class VectorContractRewriter {
+protected:
+  // Designate the operation (resp. instruction) used to do sub-tile matrix
+  // multiplications.
+  enum class MMLA {
+    Nop,
+    SignedInt,   // smmla
+    UnsignedInt, // ummla
+    MixedInt,    // usmmla
+    Bfloat       // bfmmla
+  };
+
+  // Lower-level operation to be emitted.
+  MMLA mmlaOp = MMLA::Nop;
+
+  // The operand tiles. These are not necessarily the operands of
+  // `vector.contract`, for example they could be operands to `arith.extsi`
+  // that is in turn fed into `vector.contract`.
+  Value lhs;
+  Value rhs;
+  Value acc;
+
+  // Conventional names for matrix dimensions.
+  int64_t M = 0;
+  int64_t N = 0;
+  int64_t K = 0;
+
+  // Single-dimensional vector types for the operands of the ArmSVE dialect
+  // op.
+  VectorType flatLhsType;
+  VectorType flatRhsType;
+  VectorType flatAccType;
+
+  // Single-dimension vector type for the entire RHS tile.
+  VectorType flatRhsTileType;
+
+  // Vector type having the same number of elements as a row in the
+  // accumulator/output tile and the same element type.
+  VectorType accRowTy;
+
+  // Vector type having twice the number of elements as a row in the
+  // accumulator/output tile and the same element type.
+  VectorType accRowX2Ty;
+
+  // Vector type having half the number of elements as a row in the
+  // accumulator/output tile and an integer element type with twice the bit
+  // width.
+  VectorType accRow64Ty;
+  VectorType accRowX264Ty;
+
+  // Indicate if the operands for the ArmSVE dialect operation need to be
+  // swapped. Currently this is needed in order to emulate an "summla"
+  // operation.
+  bool swapOperands = false;
+
+  // Create the matrix multiply and accumulate operation according to
+  // `mmlaOp`.
+  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+                   Value lhs, Value rhs);
+
+  // Check general preconditions for applying the transformation, common to the
+  // integer and the bfloat16 case.
+  LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter);
+
+public:
+  VectorContractRewriter() = default;
+
+  // Do the actual rewrite. This member function is shared by both integer and
+  // bfloat16 rewrites.
+  Value rewrite(vector::ContractionOp op, PatternRewriter &rewriter);
+};
+
+Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
+                                         Location loc, Value acc, Value lhs,
+                                         Value rhs) {
+  if (swapOperands)
+    std::swap(lhs, rhs);
+
+  switch (mmlaOp) {
+  case MMLA::SignedInt:
+    return rewriter.create<arm_sve::SmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+  case MMLA::UnsignedInt:
+    return rewriter.create<arm_sve::UmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+  case MMLA::MixedInt:
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+  case MMLA::Bfloat:
+    return rewriter.create<arm_sve::BfmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+  default:
+    llvm_unreachable("Uninitialized operation kind");
+  }
+}
+
+LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
+                                            PatternRewriter &rewriter) {
+  // Check iterator types for matrix multiplication.
+  auto itTypes = op.getIteratorTypesArray();
+  if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+      itTypes[1] != vector::IteratorType::parallel ||
+      itTypes[2] != vector::IteratorType::reduction)
+    return rewriter.notifyMatchFailure(
+        op, "iterator types do not correspond to matrix multiplication");
+
+  // Check permutation maps. For now only accept
+  //   lhs: (d0, d1, d2) -> (d0, d2)
+  //   rhs: (d0, d1, d2) -> (d1, d2)
+  //   acc: (d0, d1, d2) -> (d0, d1)
+  // This corresponds to matrix multiplication with transposed RHS.
+  if (op.getIndexingMapsArray()[0] !=
+          AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                               op.getContext()) ||
+      op.getIndexingMapsArray()[1] !=
+          AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                               op.getContext()) ||
+      op.getIndexingMapsArray()[2] != AffineMap::getMultiDimMapWithTargets(
+                                          3, ArrayRef{0u, 1u}, op.getContext()))
+    return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
+
+  // Check the combining kind is addition.
+  if (op.getKind() != vector::CombiningKind::ADD)
+    return rewriter.notifyMatchFailure(op, "combining kind is not an addition");
+
+  return success();
+}
+
+Value VectorContractRewriter::rewrite(vector::ContractionOp op,
+                                      PatternRewriter &rewriter) {
+  Location loc = op.getLoc();
+
+  // Extract LHS sub-tiles with logical shape <2xK>.
+  SmallVector<Value> lhsTile;
+  for (int64_t i = 0; i < M; i += 2) {
+    // Extract two consecutive rows of the LHS tile.
+    auto r0 =
+        rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i});
+    auto r1 =
+        rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i + 1});
+    // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
+    SmallVector<int64_t> shuffleIdx(2 * K);
+    std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
+    auto t = rewriter.create<vector::ShuffleOp>(loc, r0, r1, shuffleIdx);
+    // Turn it into a scalable vector.
+    auto s = rewriter.create<vector::ScalableInsertOp>(
+        loc, t, rewriter.create<ub::PoisonOp>(loc, flatLhsType), 0);
+    // Replicate the sub-tile VSCALE times to fill the entire vector.
+    auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+    lhsTile.push_back(r);
+  }
+
+  // "Flatten" the RHS tile from <[N]xK> to <[N*K]>.
+  auto rhs = rewriter.create<vector::ShapeCastOp>(this->rhs.getLoc(),
+                                                  flatRhsTileType, this->rhs);
+
+  // Extract the RHS sub-tiles with logical shape <Kx[2]>.
+  SmallVector<Value> rhsTile;
+  for (int64_t j = 0; j < N; j += 2)
+    rhsTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+        loc, flatRhsType, rhs, j * K));
+
+  // Extract and pack the ACC sub-tiles.
+  SmallVector<Value> accTile;
+  for (int64_t i = 0; i < M; i += 2) {
+    // Extract two consecutive rows of the accumulator tile.
+    auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                 ArrayRef<int64_t>{i});
+    auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                 ArrayRef<int64_t>{i + 1});
+    Value accTileVec;
+    if (swapOperands) {
+      // We are performing the operation with swapped LHS and RHS we need to
+      // transpose e...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/147052


More information about the Mlir-commits mailing list