[llvm-branch-commits] [mlir] [MLIR][ArmSVE] Add initial lowering of vector.contract to SVE `*MMLA` instructions (PR #135636)

Momchil Velikov via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Apr 15 08:54:33 PDT 2025


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/135636

>From 8e87a7f3b1438d9542d28c90eb9593ebe8cf6500 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 8 Apr 2025 14:43:54 +0000
Subject: [PATCH] [MLIR][ArmSVE] Add initial lowering of `vector.contract` to
 SVE `*MMLA` instructions

---
 mlir/include/mlir/Conversion/Passes.td        |   4 +
 .../Dialect/ArmSVE/Transforms/Transforms.h    |   3 +
 .../Conversion/VectorToLLVM/CMakeLists.txt    |   1 +
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   7 +
 .../LowerContractionToSMMLAPattern.cpp        |   5 +-
 .../Dialect/ArmSVE/Transforms/CMakeLists.txt  |   1 +
 .../LowerContractionToSVEI8MMPattern.cpp      | 304 ++++++++++++++++++
 .../Vector/CPU/ArmSVE/vector-smmla.mlir       |  94 ++++++
 .../Vector/CPU/ArmSVE/vector-summla.mlir      |  85 +++++
 .../Vector/CPU/ArmSVE/vector-ummla.mlir       |  94 ++++++
 .../Vector/CPU/ArmSVE/vector-usmmla.mlir      |  95 ++++++
 .../CPU/ArmSVE/contraction-smmla-4x8x4.mlir   | 117 +++++++
 .../ArmSVE/contraction-smmla-8x8x8-vs2.mlir   | 159 +++++++++
 .../CPU/ArmSVE/contraction-summla-4x8x4.mlir  | 118 +++++++
 .../CPU/ArmSVE/contraction-ummla-4x8x4.mlir   | 119 +++++++
 .../CPU/ArmSVE/contraction-usmmla-4x8x4.mlir  | 117 +++++++
 16 files changed, 1322 insertions(+), 1 deletion(-)
 create mode 100644 mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
 create mode 100644 mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of ArmSVE dialect while lowering the vector "
        "dialect.">,
+    Option<"armI8MM", "enable-arm-i8mm",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_I8MM instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
 void populateArmSVELegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
+void populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns);
+
 /// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
 /// intrinsics.
 void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
   MLIRVectorToLLVM
 
   MLIRArmNeonDialect
+  MLIRArmNeonTransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
+    if (armI8MM) {
+      if (armNeon)
+        arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+      if (armSVE)
+        populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+    }
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..e807b233aa7aa 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
     // Avoid 0-D vectors and 1-D rhs:
     if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
       return failure();
+    // Avoid scalable vectors.
+    if (lhsType.isScalable() || rhsType.isScalable())
+      return failure();
     auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
     auto dimK = rhsType.getDimSize(1);
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
 void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
+  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
 }
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index a70c489a51fea..65f98b44b1b69 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSVETransforms
   LegalizeForLLVMExport.cpp
   LegalizeVectorStorage.cpp
+  LowerContractionToSVEI8MMPattern.cpp
 
   DEPENDS
   MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
new file mode 100644
index 0000000000000..c0620c71440bc
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -0,0 +1,304 @@
+//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// If `v` is the result of an operation `T` (a sign- or zero-extend) from a
+// vector of `i8Ty` to a vector of `i32Ty`, return its input; else nullopt.
+template <typename T>
+inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
+                         std::is_base_of_v<arith::ExtUIOp, T>),
+                        std::optional<Value>>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+  auto ext = dyn_cast_or_null<T>(v.getDefiningOp());
+  if (!ext)
+    return std::nullopt;
+
+  // Source and result must be vectors with the expected element types.
+  Value src = ext.getIn();
+  auto srcTy = dyn_cast<VectorType>(src.getType());
+  auto dstTy = dyn_cast<VectorType>(ext.getType());
+  const bool fromI8 = srcTy && srcTy.getElementType() == i8Ty;
+  const bool toI32 = dstTy && dstTy.getElementType() == i32Ty;
+  if (!fromI8 || !toI32)
+    return std::nullopt;
+
+  return src;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+  Signed,      // smmla: LHS and RHS both sign-extended
+  Unsigned,    // ummla: LHS and RHS both zero-extended
+  Mixed,       // usmmla: zero-extended LHS, sign-extended RHS
+  MixedSwapped // usmmla with LHS and RHS swapped (sign-ext LHS, zero-ext RHS)
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+                 mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+  switch (op) {
+  case MMLA::Signed:
+    return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Unsigned:
+    return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::Mixed:
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+  case MMLA::MixedSwapped:
+    // ACC arrives transposed, result is transposed later: just swap operands.
+    return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+  }
+  llvm_unreachable("unhandled MMLA kind"); // the switch above covers all kinds
+}
+
+// Lowers a 2-D i8 x i8 -> i32 `vector.contract` (fixed LHS, scalable RHS) to
+// a grid of `arm_sve` *mmla operations on packed 2x2 accumulator sub-tiles.
+class LowerContractionToSVEI8MMPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    mlir::VectorType lhsType = op.getLhsType();
+    mlir::VectorType rhsType = op.getRhsType();
+
+    // For now handle LHS<Mx8> and transposed RHS<[N]x8> - these are the types
+    // we eventually expect from MMT4D. M and N must be even and at least 2;
+    // only the N dimension may be scalable.
+    if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+        rhsType.getRank() != 2)
+      return failure();
+
+    if (lhsType.isScalable() || !rhsType.isScalable())
+      return failure();
+
+    // M, N, and K are the conventional names for matrix dimensions in the
+    // context of matrix multiplication.
+    auto M = lhsType.getDimSize(0);
+    auto N = rhsType.getDimSize(0);
+    auto K = rhsType.getDimSize(1);
+
+    if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+        N % 2 != 0 || !rhsType.getScalableDims()[0] ||
+        rhsType.getScalableDims()[1])
+      return failure();
+
+    // Check permutation maps. For now only accept
+    //   lhs: (d0, d1, d2) -> (d0, d2)
+    //   rhs: (d0, d1, d2) -> (d1, d2)
+    //   acc: (d0, d1, d2) -> (d0, d1)
+    // Note: RHS is transposed.
+    if (op.getIndexingMapsArray()[0] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[1] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[2] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+                                                 op.getContext()))
+      return failure();
+
+    // Check iterator types for matrix multiplication.
+    auto itTypes = op.getIteratorTypesArray();
+    if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+        itTypes[1] != vector::IteratorType::parallel ||
+        itTypes[2] != vector::IteratorType::reduction)
+      return failure();
+
+    // Check the combining kind is addition.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    // Check the output is a vector of i32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getType());
+    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+      return failure();
+
+    // Check inputs are sign-/zero-extensions from i8 to i32 and get the values
+    // before the extension. All four signed/unsigned operand combinations are
+    // supported; determine which operation each one lowers to.
+    MMLA mmlaOp = MMLA::Signed;
+    auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
+        op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::Unsigned;
+      maybeLhs = extractExtOperand<arith::ExtUIOp>(
+          op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeLhs)
+      return failure();
+
+    auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
+        op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::Unsigned)
+        mmlaOp = MMLA::Mixed;
+    } else {
+      if (mmlaOp == MMLA::Signed)
+        mmlaOp = MMLA::MixedSwapped;
+      maybeRhs = extractExtOperand<arith::ExtUIOp>(
+          op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    }
+    if (!maybeRhs)
+      return failure();
+
+    // One-dimensional vector types for arm_sve.*mmla
+    auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
+    auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});
+
+    // Extract LHS sub-tiles.
+    SmallVector<Value> lhsTile;
+    for (int64_t i = 0; i < M; i += 2) {
+      // Extract two consecutive rows of the LHS tile.
+      auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+                                                   ArrayRef<int64_t>{i});
+      auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
+                                                   ArrayRef<int64_t>{i + 1});
+      // Concatenate to obtain a 16 x i8 flattened sub-tile.
+      auto t = rewriter.create<vector::ShuffleOp>(
+          loc, r0, r1,
+          llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
+                                  14, 15});
+      // Turn it into a scalable vector.
+      auto s = rewriter.create<vector::ScalableInsertOp>(
+          loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
+      // Replicate the sub-tile VSCALE times to fill the entire vector.
+      auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+      lhsTile.push_back(r);
+    }
+
+    // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
+    auto RHS = rewriter.create<vector::ShapeCastOp>(
+        maybeRhs->getLoc(),
+        VectorType::get(8 * N, rewriter.getI8Type(), {true}), *maybeRhs);
+
+    // Extract the RHS sub-tiles.
+    SmallVector<Value> rhsTile;
+    for (int64_t j = 0; j < N; j += 2)
+      rhsTile.push_back(
+          rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, RHS, j * 8));
+
+    // Handy types for packing/unpacking of the accumulator tile.
+    auto accRowTy = VectorType::get(N, rewriter.getI32Type(), {true});
+    auto accRowX2Ty = VectorType::get(2 * N, rewriter.getI32Type(), {true});
+    auto accRow64Ty = VectorType::get(N / 2, rewriter.getI64Type(), {true});
+    auto accRowX264Ty = VectorType::get(N, rewriter.getI64Type(), {true});
+
+    // Extract and pack the ACC sub-tiles.
+    SmallVector<Value> accTile;
+    for (int64_t i = 0; i < M; i += 2) {
+      // Extract two consecutive rows of the accumulator tile.
+      auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                   ArrayRef<int64_t>{i});
+      auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+                                                   ArrayRef<int64_t>{i + 1});
+      Value accTileVec;
+      if (mmlaOp == MMLA::MixedSwapped) {
+        // We need to swap the positions of the LHS and RHS (since we don't have
+        // a signed * unsigned operation), but then each individual 2x2 tile of
+        // the accumulator and (later) the result need to be transposed.
+        accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
+      } else {
+        // Bitcast them to 64-bit elements, so subsequent
+        // interleave/deinterleave work on pairs of 32-bit numbers.
+        auto r0_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
+        auto r1_i64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
+
+        // Interleave the rows, effectively flattening each 2x2 tile into 4
+        // consecutive elements.
+        auto intr_i64 =
+            rewriter.create<vector::InterleaveOp>(loc, r0_i64, r1_i64);
+
+        // Bitcast back to 32-bit elements.
+        accTileVec =
+            rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intr_i64);
+      }
+      // Extract ACC sub-tiles.
+      for (int64_t j = 0; j < N; j += 2)
+        accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+            loc, nxv4i32, accTileVec, j * 2));
+    }
+
+    // Emit sub-tile matrix multiplications.
+    SmallVector<Value> outTile;
+    for (int64_t i = 0; i < M / 2; ++i)
+      for (int64_t j = 0; j < N / 2; ++j) {
+        Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
+                                accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
+        outTile.push_back(mmla);
+      }
+
+    // Unpack the OUT sub-tiles and insert into the result.
+    Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
+    for (int64_t i = 0; i < M / 2; ++i) {
+      // Collect a number of sub-tiles in a row.
+      Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
+      for (int64_t j = 0; j < N / 2; ++j)
+        row = rewriter.create<vector::ScalableInsertOp>(
+            loc, outTile[i * N / 2 + j], row, j * 4);
+
+      // Unpack the row into two output rows. With transposed out sub-tiles,
+      // that means separating even and odd elements, i.e. a simple
+      // deinterleave. Otherwise, the deinterleave is by pairs.
+      Value out0, out1;
+      if (mmlaOp == MMLA::MixedSwapped) {
+        auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
+        out0 = tmp.getRes1();
+        out1 = tmp.getRes2();
+      } else {
+        // Deinterleave by pairs.
+        auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
+        auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
+
+        // Bitcast back into 32-bit elements and insert into the result.
+        out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+                                                  deintr64.getRes1());
+        out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
+                                                  deintr64.getRes2());
+      }
+      result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
+      result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<LowerContractionToSVEI8MMPattern>(patterns.getContext(),
+                                                 /*benefit=*/2);
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
new file mode 100644
index 0000000000000..2535ee9181c13
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
@@ -0,0 +1,94 @@
+// RUN:  mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_smmla
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
+// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3.
+// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
+// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Do the sub-tile matrix multiplications
+// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows  0 and 1
+// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+
+// Same for result rows 2 and 3
+// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
+
+func.func @test_vector_contract_to_smmla(%lhs: vector<4x8xi8>,
+              %rhs: vector<[4]x8xi8>,
+              %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
+
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  return %2 : vector<4x[4]xi32>
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
new file mode 100644
index 0000000000000..b6285d068b0f8
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
@@ -0,0 +1,85 @@
+// RUN:  mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_usmmla_rev
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK:      %[[T6:[0-9]+]] = llvm.extractvalue %[[T1:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T1]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T5:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T1]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T5]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T0:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T0]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T21:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T19]], %[[T20]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
+// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.intr.vector.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T23:[0-9]+]] = llvm.intr.vector.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3.
+// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.extractvalue %[[T0]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.extractvalue %[[T0]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T26:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T24]], %[[T25]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
+// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.intr.vector.extract %[[T26]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.intr.vector.extract %[[T26]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Do the sub-tile matrix multiplications
+// CHECK-NEXT: %[[T29:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T22]], %[[T17]], %[[T10]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T30:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T23]], %[[T18]], %[[T10]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T31:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T27]], %[[T17]], %[[T15]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T32:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T28]], %[[T18]], %[[T15]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows  0 and 1
+// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.insert %[[T29]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.insert %[[T30]], %[[T33]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T35:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T34]]) : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+// CHECK-NEXT: %[[T36:[0-9]+]] = llvm.extractvalue %[[T35]][0] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+// CHECK-NEXT: %[[T37:[0-9]+]] = llvm.extractvalue %[[T35]][1] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+// CHECK-NEXT: %[[T38:[0-9]+]] = llvm.insertvalue %[[T36]], %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.insertvalue %[[T37]], %[[T38]][1] : !llvm.array<4 x vector<[4]xi32>>
+
+// Same for result rows 2 and 3
+// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T31]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.intr.vector.insert %[[T32]], %[[T40]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.insertvalue %[[T43]], %[[T39]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.insertvalue %[[T44]], %[[T45]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T47:[0-9]+]] = builtin.unrealized_conversion_cast %[[T46]] : !llvm.array<4 x vector<[4]xi32>> to vector<4x[4]xi32>
+
+func.func @test_vector_contract_to_usmmla_rev(
+  %lhs: vector<4x8xi8>,
+  %rhs: vector<[4]x8xi8>,
+  %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
+
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  return %2 : vector<4x[4]xi32>
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
new file mode 100644
index 0000000000000..cde57842295f7
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
@@ -0,0 +1,94 @@
+// RUN:  mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_ummla
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
+// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3.
+// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
+// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Do the sub-tile matrix multiplications
+// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.ummla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.ummla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.ummla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.ummla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows  0 and 1
+// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+
+// Same for result rows 2 and 3
+// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
+
+func.func @test_vector_contract_to_ummla(%lhs: vector<4x8xi8>,
+              %rhs: vector<[4]x8xi8>,
+              %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
+
+  %0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  return %2 : vector<4x[4]xi32>
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
new file mode 100644
index 0000000000000..d0eef9fb9769c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
@@ -0,0 +1,95 @@
+// RUN:  mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+// CHECK-LABEL: @test_vector_contract_to_usmmla
+
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8>
+
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
+// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
+// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+
+// Extract sub-tiles from the RHS
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+
+// Extract accumulator rows 0 and 1 and pack (into "registers")
+// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
+// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3.
+// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
+// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Do the sub-tile matrix multiplications
+// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows  0 and 1
+// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+
+// Same for result rows 2 and 3
+// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
+
+func.func @test_vector_contract_to_usmmla(
+  %lhs: vector<4x8xi8>,
+  %rhs: vector<[4]x8xi8>,
+  %acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
+
+  %0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  return %2 : vector<4x[4]xi32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
new file mode 100644
index 0000000000000..88534dd2aab1e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17],
+                                   [-28,  31,   3, -44, -15, -27,  22,  35],
+                                   [-23,  39,  48,  26, -23,  32, -39, -38]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -1999,  1941,   685, -2879 )
+// CHECK: ( -3705,  2952,   987,  -685 )
+// CHECK: (  2565,  4157, -1589,  -357 )
+// CHECK: (  2383, -2252,    32, -1365 )
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
new file mode 100644
index 0000000000000..ce57be91fa540
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
@@ -0,0 +1,159 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c256 = arith.constant 256 : i32
+  func.call @setArmVLBits(%c256) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46,  -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39, -48, -31, -25, -21],
+                                   [-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49],
+                                   [-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17]]> : vector<8x8xi32>
+  %acc_m = memref.alloca() : memref<8x8xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+  vector.print %acc4 : vector<[4]xi32>
+  vector.print %acc5 : vector<[4]xi32>
+  vector.print %acc6 : vector<[4]xi32>
+  vector.print %acc7 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-28,  31,   3, -44, -15, -27,  22,  35],
+                                   [-23,  39,  48,  26, -23,  32, -39, -38],
+                                   [ -3,   9,  43, -30, -32,  39,  41, -39],
+                                   [-13, -21, -25,  27,  47, -36, -11, -11],
+                                   [ -4, -20,  36,  11,  13, -23,  24, -13],
+                                   [-20,  30,  -5,   1,  42, -37, -22,  35],
+                                   [-22,  38,  -4,  44,  25, -31,  23, -39],
+                                   [-45,  -4, -31, -24,  14, -41, -47,  22]]> : vector<8x8xi8>
+
+  %lhs_m = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8>
+  %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8>
+  %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8>
+  %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8>
+  %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+  vector.print %lhs4 : vector<8xi8>
+  vector.print %lhs5 : vector<8xi8>
+  vector.print %lhs6 : vector<8xi8>
+  vector.print %lhs7 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-40, -11, -36,  36,  -1,  20,  14, -32],
+                                   [ 46, -45, -48, -46, -24,  31, -36,  22],
+                                   [  2,  36,  45, -29, -37, -49, -20, -35],
+                                   [ -6,  23,  23,  15,  20,   4,  -8,  -2],
+                                   [-35,  -6,  16,  49, -50,   9, -44,  13],
+                                   [ 24,   1,  -4, -44,  41,  15, -43,  44],
+                                   [ 44,   0, -10,  41,  22,  44, -40,   0],
+                                   [-33,  19,  27,  22,  38, -17,  23,  -9]]> : vector<8x8xi8>
+
+  %rhs_m = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+  vector.print %u4 : vector<[4]xi32>
+  vector.print %u5 : vector<[4]xi32>
+  vector.print %u6 : vector<[4]xi32>
+  vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282,  2728,  -410, -1328,   882, -5498,   732 )
+// CHECK: (  1012, -4237,  4154,  2624,  5225, -2338,  2011,  1374 )
+// CHECK: (    -8, -1611,  2905,    -1, -1068, -3155, -2428,   153 )
+// CHECK: (  2034, -1768, -2092,   284,  -792,   -23,   668,  2172 )
+// CHECK: (  -248, -3728,  1214,   555,  -668, -2114, -1794,  2560 )
+// CHECK: ( -1484, -2642,   297,  1551,  -483,  3173,  -576,  2570 )
+// CHECK: (  3098, -7851,  1366,  1892,  -427, -4533,  -819,  4698 )
+// CHECK: (  -135,  1247,   765,  -479,  1245,  3074, -2281,   -23 )
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {  // Runs one i8 contraction (LHS sign-extended, RHS zero-extended) and prints the operands and the result.
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()  // NOTE(review): presumably pins the SVE vector length to 128 bits (vscale = 1); the flat [16]xi32 and [32]xi8 reads below cover their 16- and 32-element buffers exactly only in that case - confirm.
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data: a 4x4 i32 constant is written to memory, read back as one flat scalable vector and reshaped to 4x[4].
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data: fixed-size 4x8 i8 with negative entries, round-tripped through memory; sign-extended to i32 before the contraction.
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data: 4x8 i8 with entries above 127, read back flat and reshaped to [4]x8; zero-extended to i32 before the contraction.
+  %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175,  82,  99],
+                                   [221,  25, 164,  97, 156, 221, 218, 177],
+                                   [171, 160, 219, 191, 144,  45, 161, 210],
+                                   [223, 165, 123,  99, 108,  86,  37,  92]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication: result[i][j] = acc[i][j] + sum_k sext(lhs[i][k]) * zext(rhs[j][k]); per #packed_maps the RHS is indexed (j, k), i.e. used transposed.
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication, one row per line; the CHECK lines below hold the expected rows.
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: (  -7613,  -8386, -15938,  -6521 )
+// CHECK: (   9468,  18750,   9199,   5764 )
+// CHECK: (  33655,  41064,  48900,  31627 )
+  return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..b5a7675f59881
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {  // Runs one i8 contraction (both operands zero-extended) and prints the operands and the result.
+
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()  // NOTE(review): presumably pins the SVE vector length to 128 bits (vscale = 1); the flat [16]xi32 and [32]xi8 reads below cover their 16- and 32-element buffers exactly only in that case - confirm.
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+
+// Accumulator test data: a 4x4 i32 constant is written to memory, read back as one flat scalable vector and reshaped to 4x[4].
+  %acc_cst = arith.constant dense<[[16, 16, 48, 40],
+                                   [40, 24, 35, 12],
+                                   [33, 24, 29, 19],
+                                   [28, 13, 33, 18]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+  
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data: fixed-size 4x8 i8, round-tripped through memory; zero-extended to i32 before the contraction.
+  %lhs_cst = arith.constant dense<[[35, 42, 37, 49, 36, 36, 23, 33],
+                                   [39, 34, 33, 45, 43, 10, 44, 47],
+                                   [18, 35, 29, 25, 36, 33, 28, 29],
+                                   [26, 49, 43, 32, 27, 16, 45, 33]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data: 4x8 i8, read back flat and reshaped to [4]x8; zero-extended to i32 before the contraction.
+  %rhs_cst = arith.constant dense<[[18, 31, 37, 35, 44, 22, 37, 28],
+                                   [21, 22, 49, 39, 30, 28, 35, 37],
+                                   [21, 47, 39, 35, 23, 43, 24, 49],
+                                   [49, 49, 40, 32, 37, 20, 47, 40]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+  
+  // Matrix multiplication: result[i][j] = acc[i][j] + sum_k zext(lhs[i][k]) * zext(rhs[j][k]); per #packed_maps the RHS is indexed (j, k), i.e. used transposed.
+  %0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication, one row per line; the CHECK lines below hold the expected rows.
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( 9183, 9513, 10460, 11314 )
+// CHECK: ( 9648, 9812, 10092, 12088 )
+// CHECK: ( 7548, 7625,  8398,  9044 )
+// CHECK: ( 8855, 9046,  9685, 11191 )
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir
new file mode 100644
index 0000000000000..a25a51dd7018c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {  // Runs one i8 contraction (LHS zero-extended, RHS sign-extended) and prints the operands and the result.
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()  // NOTE(review): presumably pins the SVE vector length to 128 bits (vscale = 1); the flat [16]xi32 and [32]xi8 reads below cover their 16- and 32-element buffers exactly only in that case - confirm.
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data: a 4x4 i32 constant is written to memory, read back as one flat scalable vector and reshaped to 4x[4].
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data: fixed-size 4x8 i8 with entries above 127, round-tripped through memory; zero-extended to i32 before the contraction.
+  %lhs_cst = arith.constant dense<[[153, 161,  24, 157, 211, 154,  52,  27],
+                                   [168,  77, 136, 124, 249,  28,  13, 122],
+                                   [ 97,  82, 181,  39,  53,  25,  80, 240],
+                                   [184, 227, 106, 165, 126, 113, 121, 228]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data: 4x8 i8 with negative entries, read back flat and reshaped to [4]x8; sign-extended to i32 before the contraction.
+  %rhs_cst = arith.constant dense<[[ 40,  27,  37,  43,  38,  -6,  37,  49],
+                                   [-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17],
+                                   [-28,  31,   3, -44, -15, -27,  22,  35]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication: result[i][j] = acc[i][j] + sum_k zext(lhs[i][k]) * sext(rhs[j][k]); per #packed_maps the RHS is indexed (j, k), i.e. used transposed.
+  %0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication, one row per line; the CHECK lines below hold the expected rows.
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+ // CHECK: ( 28403,  445,  -2759, -11409 )
+ // CHECK: ( 34908, 1047,    142,  -7274 )
+ // CHECK: ( 31032, 6807,  -2378,   7382 )
+ // CHECK: ( 44217, 6396, -10930,    623 )
+  return
+}



More information about the llvm-branch-commits mailing list