[llvm-branch-commits] [mlir] [MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD Op (PR #140572)
Momchil Velikov via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon May 19 09:43:39 PDT 2025
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/140572
None
>From 251f93ea5b87acefac1fbcd6951c3b7870eff83c Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 16 May 2025 15:47:36 +0000
Subject: [PATCH] [MLIR] Add apply_patterns.vector.arm_sve.lower_contraction TD
Op
---
.../mlir/Dialect/ArmSVE/CMakeLists.txt | 1 +
.../TransformOps/ArmSVEVectorTransformOps.h | 31 +++
.../TransformOps/ArmSVEVectorTransformOps.td | 26 ++
.../ArmSVE/TransformOps/CMakeLists.txt | 6 +
mlir/include/mlir/InitAllExtensions.h | 2 +
mlir/lib/Dialect/ArmSVE/CMakeLists.txt | 1 +
.../TransformOps/ArmSVEVectorTransformOps.cpp | 54 ++++
.../ArmSVE/TransformOps/CMakeLists.txt | 19 ++
.../Vector/CPU/ArmSVE/vector-smmla.mlir | 263 ++++++++++--------
.../Vector/CPU/ArmSVE/vector-summla.mlir | 123 ++++----
.../Vector/CPU/ArmSVE/vector-ummla.mlir | 138 +++++----
.../Vector/CPU/ArmSVE/vector-usmmla.mlir | 138 +++++----
12 files changed, 512 insertions(+), 290 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h
create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
create mode 100644 mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
create mode 100644 mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt
diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h
new file mode 100644
index 0000000000000..7f22cd1fe6435
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- ArmSVEVectorTransformOps.h - Vector transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
+#define MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// ArmSVE Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arm_sve {
+void registerTransformDialectExtension(DialectRegistry &registry);
+
+} // namespace arm_sve
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
new file mode 100644
index 0000000000000..81b59340f3b0d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
@@ -0,0 +1,26 @@
+//===- ArmSVEVectorTransformOps.td - Arm SVE transform ops--*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef ARMSVE_VECTOR_TRANSFORM_OPS
+#define ARMSVE_VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+
+def ApplyArmSVELowerContractionPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.vector.arm_sve.lower_contraction",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contraction-like operations should be lowered to
+ finer-grained vector primitives using the ArmSVE dialect.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+#endif // ARMSVE_VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..ce8d8fea7f188
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td)
+mlir_tablegen(ArmSVEVectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(ArmSVEVectorTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRArmSVEVectorTransformOpsIncGen)
+
+add_mlir_doc(ArmSVEVectorTransformOps ArmSVEVectorTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 37e4904cb48ed..767c7099accbb 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -34,6 +34,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -106,6 +107,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
vector::registerTransformDialectExtension(registry);
+ arm_sve::registerTransformDialectExtension(registry);
// Translation extensions need to be registered by calling
// `registerAllToLLVMIRTranslations` (see All.h).
diff --git a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
new file mode 100644
index 0000000000000..b2ca4fc1eaa8c
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
@@ -0,0 +1,54 @@
+//===- ArmSVEVectorTransformOps.cpp - Implementation transform ops -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
+
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Apply...PatternsOp
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ArmSVEVectorTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ ArmSVEVectorTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ ArmSVEVectorTransformDialectExtension)
+
+ ArmSVEVectorTransformDialectExtension() {
+ declareGeneratedDialect<arm_sve::ArmSVEDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc"
+
+void mlir::arm_sve::registerTransformDialectExtension(
+ DialectRegistry &registry) {
+ registry.addExtensions<ArmSVEVectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..8771826e08913
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_dialect_library(MLIRArmSVEVectorTransformOps
+ ArmSVEVectorTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE/TransformOps
+
+ DEPENDS
+ MLIRArmSVEVectorTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRTransformDialect
+ MLIRArmSVEDialect
+ MLIRArmSVETransforms
+ )
+
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
index af0cb37e2d249..3991038761e8d 100644
--- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-smmla.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
#attrs = {
indexing_maps = [
@@ -12,77 +12,82 @@
// CHECK-LABEL: @test_vector_contract_to_smmla
+// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8>
+// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8>
+// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32>
+
+// CHECK: [[P0:[0-9]+]] = ub.poison : vector<[8]xi32>
+// CHECK-NEXT: [[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32>
+// CHECK-NEXT: [[P2:[0-9]+]] = ub.poison : vector<[16]xi8>
+
// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
-// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
-// Replicate across the entire length of the scalabale vector
-// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8>
-// Same for LHS rows 2 and 4
-// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
-// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8>
// Extract sub-tiles from the RHS
-// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
-// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
-// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8>
// Extract accumulator rows 0 and 1 and pack (into "registers")
-// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
-
-// Same for accumulator rows 2 and 3.
-// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T18:[0-9]+]] = vector.bitcast %[[T16]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T19:[0-9]+]] = vector.bitcast %[[T17]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T21:[0-9]+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3
+// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T26:[0-9]+]] = vector.bitcast %[[T24]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T27:[0-9]+]] = vector.bitcast %[[T25]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T29:[0-9]+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xi32> from vector<[8]xi32>
// Do the sub-tile matrix multiplications
-// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-
-// Unpack (from "registers") and insert in the output result rows 0 and 1
-// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.smmla %[[ACC_0]], %[[LHS_0]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.smmla %[[ACC_1]], %[[LHS_0]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.smmla %[[ACC_2]], %[[LHS_1]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.smmla %[[ACC_3]], %[[LHS_1]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows 0 and 1
+// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T38:[0-9]+]] = vector.bitcast %[[T37]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_0:[0-9]+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_1:[0-9]+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32>
// Same for result rows 2 and 3
-// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
-
+// CHECK-NEXT: %[[T43:[0-9]+]] = vector.scalable.insert %[[PACK_RES_10]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T44:[0-9]+]] = vector.scalable.insert %[[PACK_RES_11]], %[[T43]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T45:[0-9]+]] = vector.bitcast %[[T44]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %res1_0, %res2_1 = vector.deinterleave %[[T45]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_2:[0-9]+]] = vector.bitcast %res1_0 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_3:[0-9]+]] = vector.bitcast %res2_1 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_2:[0-9]+]] = vector.insert %[[UNPACK_RES_2]], %[[TMP_OUT_1]] [2] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[OUT:[0-9]+]] = vector.insert %[[UNPACK_RES_3]], %[[TMP_OUT_2]] [3] : vector<[4]xi32> into vector<4x[4]xi32>
+
+// CHECK-NEXT: return %[[OUT]] : vector<4x[4]xi32>
func.func @test_vector_contract_to_smmla(%lhs: vector<4x8xi8>,
%rhs: vector<[4]x8xi8>,
%acc: vector<4x[4]xi32>) -> vector<4x[4]xi32> {
@@ -97,76 +102,82 @@ func.func @test_vector_contract_to_smmla(%lhs: vector<4x8xi8>,
// CHECK-LABEL: @test_vector_contract_to_smmla_implicit_sext
+// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8>
+// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8>
+// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32>
+
+// CHECK: [[P0:[0-9]+]] = ub.poison : vector<[8]xi32>
+// CHECK-NEXT: [[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32>
+// CHECK-NEXT: [[P2:[0-9]+]] = ub.poison : vector<[16]xi8>
+
// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
-// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
-// Replicate across the entire length of the scalabale vector
-// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8>
-// Same for LHS rows 2 and 4
-// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
-// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8>
// Extract sub-tiles from the RHS
-// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
-// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
-// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8>
// Extract accumulator rows 0 and 1 and pack (into "registers")
-// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
-
-// Same for accumulator rows 2 and 3.
-// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T18:[0-9]+]] = vector.bitcast %[[T16]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T19:[0-9]+]] = vector.bitcast %[[T17]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T21:[0-9]+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3
+// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T26:[0-9]+]] = vector.bitcast %[[T24]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T27:[0-9]+]] = vector.bitcast %[[T25]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T29:[0-9]+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xi32> from vector<[8]xi32>
// Do the sub-tile matrix multiplications
-// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.smmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.smmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.smmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.smmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-
-// Unpack (from "registers") and insert in the output result rows 0 and 1
-// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.smmla %[[ACC_0]], %[[LHS_0]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.smmla %[[ACC_1]], %[[LHS_0]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.smmla %[[ACC_2]], %[[LHS_1]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.smmla %[[ACC_3]], %[[LHS_1]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows 0 and 1
+// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T38:[0-9]+]] = vector.bitcast %[[T37]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_0:[0-9]+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_1:[0-9]+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32>
// Same for result rows 2 and 3
-// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T43:[0-9]+]] = vector.scalable.insert %[[PACK_RES_10]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T44:[0-9]+]] = vector.scalable.insert %[[PACK_RES_11]], %[[T43]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T45:[0-9]+]] = vector.bitcast %[[T44]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %res1_0, %res2_1 = vector.deinterleave %[[T45]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_2:[0-9]+]] = vector.bitcast %res1_0 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_3:[0-9]+]] = vector.bitcast %res2_1 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_2:[0-9]+]] = vector.insert %[[UNPACK_RES_2]], %[[TMP_OUT_1]] [2] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[OUT:[0-9]+]] = vector.insert %[[UNPACK_RES_3]], %[[TMP_OUT_2]] [3] : vector<[4]xi32> into vector<4x[4]xi32>
+
+// CHECK-NEXT: return %[[OUT]] : vector<4x[4]xi32>
// Test a variant where the sign-extension of the operands is
// implicit. The output is identical to the one of the previous test.
@@ -179,3 +190,15 @@ func.func @test_vector_contract_to_smmla_implicit_sext(%lhs: vector<4x8xi8>,
return %0 : vector<4x[4]xi32>
}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.arm_sve.lower_contraction
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
index b6285d068b0f8..ffcb437f189f9 100644
--- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-summla.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
#packed_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>,
@@ -8,66 +8,77 @@
// CHECK-LABEL: @test_vector_contract_to_usmmla_rev
-// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
-// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T1:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T1]][1] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T5:[0-9]+]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8>
+// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8>
+// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32>
+
+// CHECK: [[P0:[0-9]+]] = ub.poison : vector<[8]xi32>
+// CHECK-NEXT: [[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32>
+// CHECK-NEXT: [[P2:[0-9]+]] = ub.poison : vector<[16]xi8>
-// Replicate across the entire length of the scalabale vector
-// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
+// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
-// Same for LHS rows 2 and 4
-// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T1]][3] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T5]][0] : vector<16xi8> into vector<[16]xi8>
-// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8>
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8>
// Extract sub-tiles from the RHS
-// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
-// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
-// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8>
// Extract accumulator rows 0 and 1 and pack (into "registers")
-// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T0:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T0]][1] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T21:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T19]], %[[T20]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
-// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.intr.vector.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T23:[0-9]+]] = llvm.intr.vector.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32>
-
-// Same for accumulator rows 2 and 3.
-// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.extractvalue %[[T0]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.extractvalue %[[T0]][3] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T26:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T24]], %[[T25]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32>
-// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.intr.vector.extract %[[T26]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.intr.vector.extract %[[T26]][4] : vector<[4]xi32> from vector<[8]xi32>
+// Note the lack of bitcasts (i.e. interleave by pairs). This has the effect of transposing the tile.
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T16]], %[[T17]] : vector<[4]xi32> -> vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T20]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T20]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3
+// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T24]], %[[T25]] : vector<[4]xi32> -> vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T28]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T28]][4] : vector<[4]xi32> from vector<[8]xi32>
// Do the sub-tile matrix multiplications
-// CHECK-NEXT: %[[T29:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T22]], %[[T17]], %[[T10]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T30:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T23]], %[[T18]], %[[T10]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T31:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T27]], %[[T17]], %[[T15]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T32:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T28]], %[[T18]], %[[T15]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-
-// Unpack (from "registers") and insert in the output result rows 0 and 1
-// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.insert %[[T29]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.insert %[[T30]], %[[T33]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T35:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T34]]) : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
-// CHECK-NEXT: %[[T36:[0-9]+]] = llvm.extractvalue %[[T35]][0] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
-// CHECK-NEXT: %[[T37:[0-9]+]] = llvm.extractvalue %[[T35]][1] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
-// CHECK-NEXT: %[[T38:[0-9]+]] = llvm.insertvalue %[[T36]], %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.insertvalue %[[T37]], %[[T38]][1] : !llvm.array<4 x vector<[4]xi32>>
+// The accumulator is transposed, the operands are in the opposite order,
+// thus the result is obtained transposed too.
+//
+// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.usmmla %[[ACC_0]], %[[RHS_0]], %[[LHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.usmmla %[[ACC_1]], %[[RHS_1]], %[[LHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.usmmla %[[ACC_2]], %[[RHS_0]], %[[LHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.usmmla %[[ACC_3]], %[[RHS_1]], %[[LHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows 0 and 1
+// Again, no bitcast/interleave-by-pairs, i.e. the transposed result of sub-tile
+// multiplications is unpacked into the correct output layout.
+//
+// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_0:[_a-z0-9]+]], %[[UNPACK_RES_1:[_a-z0-9]+]] = vector.deinterleave %[[T37]] : vector<[8]xi32> -> vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32>
// Same for result rows 2 and 3
-// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T31]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.intr.vector.insert %[[T32]], %[[T40]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
-// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
-// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
-// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.insertvalue %[[T43]], %[[T39]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.insertvalue %[[T44]], %[[T45]][3] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T47:[0-9]+]] = builtin.unrealized_conversion_cast %[[T46]] : !llvm.array<4 x vector<[4]xi32>> to vector<4x[4]xi32>
+// CHECK-NEXT: %[[T43:[0-9]+]] = vector.scalable.insert %[[PACK_RES_10]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T44:[0-9]+]] = vector.scalable.insert %[[PACK_RES_11]], %[[T43]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_2:[_a-z0-9]+]], %[[UNPACK_RES_3:[_a-z0-9]+]] = vector.deinterleave %[[T44]] : vector<[8]xi32> -> vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_2:[0-9]+]] = vector.insert %[[UNPACK_RES_2]], %[[TMP_OUT_1]] [2] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[OUT:[0-9]+]] = vector.insert %[[UNPACK_RES_3]], %[[TMP_OUT_2]] [3] : vector<[4]xi32> into vector<4x[4]xi32>
+
+// CHECK-NEXT: return %[[OUT]] : vector<4x[4]xi32>
func.func @test_vector_contract_to_usmmla_rev(
%lhs: vector<4x8xi8>,
@@ -83,3 +94,15 @@ func.func @test_vector_contract_to_usmmla_rev(
return %2 : vector<4x[4]xi32>
}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.arm_sve.lower_contraction
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
index cde57842295f7..9f63a0c1688f5 100644
--- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-ummla.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter --split-input-file | FileCheck %s
#packed_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>,
@@ -8,76 +8,82 @@
// CHECK-LABEL: @test_vector_contract_to_ummla
+// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8>
+// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8>
+// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32>
+
+// CHECK: [[P0:[0-9]+]] = ub.poison : vector<[8]xi32>
+// CHECK-NEXT: [[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32>
+// CHECK-NEXT: [[P2:[0-9]+]] = ub.poison : vector<[16]xi8>
+
// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
-// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
-// Replicate across the entire length of the scalabale vector
-// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8>
-// Same for LHS rows 2 and 4
-// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
-// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8>
// Extract sub-tiles from the RHS
-// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
-// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
-// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8>
// Extract accumulator rows 0 and 1 and pack (into "registers")
-// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
-
-// Same for accumulator rows 2 and 3.
-// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T18:[0-9]+]] = vector.bitcast %[[T16]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T19:[0-9]+]] = vector.bitcast %[[T17]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T21:[0-9]+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3
+// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T26:[0-9]+]] = vector.bitcast %[[T24]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T27:[0-9]+]] = vector.bitcast %[[T25]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T29:[0-9]+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xi32> from vector<[8]xi32>
// Do the sub-tile matrix multiplications
-// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.ummla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.ummla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.ummla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.ummla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-
-// Unpack (from "registers") and insert in the output result rows 0 and 1
-// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.ummla %[[ACC_0]], %[[LHS_0]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.ummla %[[ACC_1]], %[[LHS_0]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.ummla %[[ACC_2]], %[[LHS_1]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.ummla %[[ACC_3]], %[[LHS_1]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows 0 and 1
+// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T38:[0-9]+]] = vector.bitcast %[[T37]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_0:[0-9]+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_1:[0-9]+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32>
// Same for result rows 2 and 3
-// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T43:[0-9]+]] = vector.scalable.insert %[[PACK_RES_10]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T44:[0-9]+]] = vector.scalable.insert %[[PACK_RES_11]], %[[T43]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T45:[0-9]+]] = vector.bitcast %[[T44]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %res1_0, %res2_1 = vector.deinterleave %[[T45]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_2:[0-9]+]] = vector.bitcast %res1_0 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_3:[0-9]+]] = vector.bitcast %res2_1 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_2:[0-9]+]] = vector.insert %[[UNPACK_RES_2]], %[[TMP_OUT_1]] [2] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[OUT:[0-9]+]] = vector.insert %[[UNPACK_RES_3]], %[[TMP_OUT_2]] [3] : vector<[4]xi32> into vector<4x[4]xi32>
+
+// CHECK-NEXT: return %[[OUT]] : vector<4x[4]xi32>
func.func @test_vector_contract_to_ummla(%lhs: vector<4x8xi8>,
%rhs: vector<[4]x8xi8>,
@@ -92,3 +98,15 @@ func.func @test_vector_contract_to_ummla(%lhs: vector<4x8xi8>,
return %2 : vector<4x[4]xi32>
}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.arm_sve.lower_contraction
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
index d0eef9fb9769c..f9d8b604150e6 100644
--- a/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
+++ b/mlir/test/Dialect/Vector/CPU/ArmSVE/vector-usmmla.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
#packed_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>,
@@ -8,76 +8,82 @@
// CHECK-LABEL: @test_vector_contract_to_usmmla
+// CHECK-SAME: %[[LHS:arg0]]: vector<4x8xi8>
+// CHECK-SAME: %[[RHS:arg1]]: vector<[4]x8xi8>
+// CHECK-SAME: %[[ACC:arg2]]: vector<4x[4]xi32>
+
+// CHECK: [[P0:[0-9]+]] = ub.poison : vector<[8]xi32>
+// CHECK-NEXT: [[P1:[0-9]+]] = ub.poison : vector<4x[4]xi32>
+// CHECK-NEXT: [[P2:[0-9]+]] = ub.poison : vector<[16]xi8>
+
// Extract LHS rows 0 and 1, concatenate, turn into scalable vector
-// CHECK: %[[T6:[0-9]+]] = llvm.extractvalue %[[T4:[0-9]+]][0] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T7:[0-9]+]] = llvm.extractvalue %[[T4]][1] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T8:[0-9]+]] = llvm.shufflevector %[[T6]], %[[T7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T9:[0-9]+]] = llvm.intr.vector.insert %[[T8]], %[[T0:[0-9+]]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK: %[[T3:[0-9]+]] = vector.extract %[[LHS]][0] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T4:[0-9]+]] = vector.extract %[[LHS]][1] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T5:[0-9]+]] = vector.shuffle %[[T3]], %[[T4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T6:[0-9]+]] = vector.scalable.insert %[[T5]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
-// Replicate across the entire length of the scalabale vector
-// CHECK-NEXT: %[[T10:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T9]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Replicate across the entire length of the scalable vector
+// CHECK-NEXT: %[[LHS_0:[0-9]+]] = arm_sve.dupq_lane %[[T6]][0] : vector<[16]xi8>
-// Same for LHS rows 2 and 4
-// CHECK-NEXT: %[[T11:[0-9]+]] = llvm.extractvalue %[[T4]][2] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T12:[0-9]+]] = llvm.extractvalue %[[T4]][3] : !llvm.array<4 x vector<8xi8>>
-// CHECK-NEXT: %[[T13:[0-9]+]] = llvm.shufflevector %[[T11]], %[[T12]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
-// CHECK-NEXT: %[[T14:[0-9]+]] = llvm.intr.vector.insert %[[T13]], %[[T0]][0] : vector<16xi8> into vector<[16]xi8>
-// CHECK-NEXT: %[[T15:[0-9]+]] = "arm_sve.intr.dupq_lane"(%[[T14]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// Same for LHS rows 2 and 3
+// CHECK-NEXT: %[[T8:[0-9]+]] = vector.extract %[[LHS]][2] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T9:[0-9]+]] = vector.extract %[[LHS]][3] : vector<8xi8> from vector<4x8xi8>
+// CHECK-NEXT: %[[T10:[0-9]+]] = vector.shuffle %[[T8]], %[[T9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>, vector<8xi8>
+// CHECK-NEXT: %[[T11:[0-9]+]] = vector.scalable.insert %[[T10]], %[[P2]][0] : vector<16xi8> into vector<[16]xi8>
+// CHECK-NEXT: %[[LHS_1:[0-9]+]] = arm_sve.dupq_lane %[[T11]][0] : vector<[16]xi8>
// Extract sub-tiles from the RHS
-// CHECK-NEXT: %[[T16:[0-9]+]] = vector.shape_cast %arg1 : vector<[4]x8xi8> to vector<[32]xi8>
-// CHECK-NEXT: %[[T17:[0-9]+]] = llvm.intr.vector.extract %[[T16]][0] : vector<[16]xi8> from vector<[32]xi8>
-// CHECK-NEXT: %[[T18:[0-9]+]] = llvm.intr.vector.extract %[[T16]][16] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[T13:[0-9]+]] = vector.shape_cast %[[RHS]] : vector<[4]x8xi8> to vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_0:[0-9]+]] = vector.scalable.extract %[[T13]][0] : vector<[16]xi8> from vector<[32]xi8>
+// CHECK-NEXT: %[[RHS_1:[0-9]+]] = vector.scalable.extract %[[T13]][16] : vector<[16]xi8> from vector<[32]xi8>
// Extract accumulator rows 0 and 1 and pack (into "registers")
-// CHECK-NEXT: %[[T19:[0-9]+]] = llvm.extractvalue %[[T3:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T20:[0-9]+]] = llvm.extractvalue %[[T3]][1] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T21:[0-9]+]] = llvm.bitcast %[[T19]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T22:[0-9]+]] = llvm.bitcast %[[T20]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T23:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T21]], %[[T22]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T24:[0-9]+]] = llvm.bitcast %[[T23]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T25:[0-9]+]] = llvm.intr.vector.extract %[[T24]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T26:[0-9]+]] = llvm.intr.vector.extract %[[T24]][4] : vector<[4]xi32> from vector<[8]xi32>
-
-// Same for accumulator rows 2 and 3.
-// CHECK-NEXT: %[[T27:[0-9]+]] = llvm.extractvalue %[[T3]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T28:[0-9]+]] = llvm.extractvalue %[[T3]][3] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T29:[0-9]+]] = llvm.bitcast %[[T27]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T30:[0-9]+]] = llvm.bitcast %[[T28]] : vector<[4]xi32> to vector<[2]xi64>
-// CHECK-NEXT: %[[T31:[0-9]+]] = "llvm.intr.vector.interleave2"(%[[T29]], %[[T30]]) : (vector<[2]xi64>, vector<[2]xi64>) -> vector<[4]xi64>
-// CHECK-NEXT: %[[T32:[0-9]+]] = llvm.bitcast %[[T31]] : vector<[4]xi64> to vector<[8]xi32>
-// CHECK-NEXT: %[[T33:[0-9]+]] = llvm.intr.vector.extract %[[T32]][0] : vector<[4]xi32> from vector<[8]xi32>
-// CHECK-NEXT: %[[T34:[0-9]+]] = llvm.intr.vector.extract %[[T32]][4] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[T16:[0-9]+]] = vector.extract %[[ACC]][0] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T17:[0-9]+]] = vector.extract %[[ACC]][1] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T18:[0-9]+]] = vector.bitcast %[[T16]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T19:[0-9]+]] = vector.bitcast %[[T17]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T20:[0-9]+]] = vector.interleave %[[T18]], %[[T19]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T21:[0-9]+]] = vector.bitcast %[[T20]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_0:[0-9]+]] = vector.scalable.extract %[[T21]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_1:[0-9]+]] = vector.scalable.extract %[[T21]][4] : vector<[4]xi32> from vector<[8]xi32>
+
+// Same for accumulator rows 2 and 3
+// CHECK-NEXT: %[[T24:[0-9]+]] = vector.extract %[[ACC]][2] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T25:[0-9]+]] = vector.extract %[[ACC]][3] : vector<[4]xi32> from vector<4x[4]xi32>
+// CHECK-NEXT: %[[T26:[0-9]+]] = vector.bitcast %[[T24]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T27:[0-9]+]] = vector.bitcast %[[T25]] : vector<[4]xi32> to vector<[2]xi64>
+// CHECK-NEXT: %[[T28:[0-9]+]] = vector.interleave %[[T26]], %[[T27]] : vector<[2]xi64> -> vector<[4]xi64>
+// CHECK-NEXT: %[[T29:[0-9]+]] = vector.bitcast %[[T28]] : vector<[4]xi64> to vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_2:[0-9]+]] = vector.scalable.extract %[[T29]][0] : vector<[4]xi32> from vector<[8]xi32>
+// CHECK-NEXT: %[[ACC_3:[0-9]+]] = vector.scalable.extract %[[T29]][4] : vector<[4]xi32> from vector<[8]xi32>
// Do the sub-tile matrix multiplications
-// CHECK-NEXT: %[[T35:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T25]], %[[T10]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T36:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T26]], %[[T10]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T37:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T33]], %[[T15]], %[[T17]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-// CHECK-NEXT: %[[T38:[0-9]+]] = "arm_sve.intr.usmmla"(%[[T34]], %[[T15]], %[[T18]]) : (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) -> vector<[4]xi32>
-
-// Unpack (from "registers") and insert in the output result rows 0 and 1
-// CHECK-NEXT: %[[T39:[0-9]+]] = llvm.intr.vector.insert %[[T35]], %[[T2:[0-9]+]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T40:[0-9]+]] = llvm.intr.vector.insert %[[T36]], %[[T39]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T41:[0-9]+]] = llvm.bitcast %[[T40]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T42:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T41]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T43:[0-9]+]] = llvm.extractvalue %[[T42]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T44:[0-9]+]] = llvm.extractvalue %[[T42]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T45:[0-9]+]] = llvm.bitcast %[[T43]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T46:[0-9]+]] = llvm.bitcast %[[T44]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T47:[0-9]+]] = llvm.insertvalue %[[T45]], %[[T5:[0-9]+]][0] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T48:[0-9]+]] = llvm.insertvalue %[[T46]], %[[T47]][1] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[PACK_RES_00:[0-9]+]] = arm_sve.usmmla %[[ACC_0]], %[[LHS_0]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_01:[0-9]+]] = arm_sve.usmmla %[[ACC_1]], %[[LHS_0]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_10:[0-9]+]] = arm_sve.usmmla %[[ACC_2]], %[[LHS_1]], %[[RHS_0]] : vector<[16]xi8> to vector<[4]xi32>
+// CHECK-NEXT: %[[PACK_RES_11:[0-9]+]] = arm_sve.usmmla %[[ACC_3]], %[[LHS_1]], %[[RHS_1]] : vector<[16]xi8> to vector<[4]xi32>
+
+// Unpack (from "registers") and insert in the output result rows 0 and 1
+// CHECK-NEXT: %[[T36:[0-9]+]] = vector.scalable.insert %[[PACK_RES_00]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T37:[0-9]+]] = vector.scalable.insert %[[PACK_RES_01]], %[[T36]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T38:[0-9]+]] = vector.bitcast %[[T37]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %res1, %res2 = vector.deinterleave %[[T38]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_0:[0-9]+]] = vector.bitcast %res1 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_1:[0-9]+]] = vector.bitcast %res2 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_0:[0-9]+]] = vector.insert %[[UNPACK_RES_0]], %[[P1]] [0] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_1:[0-9]+]] = vector.insert %[[UNPACK_RES_1]], %[[TMP_OUT_0]] [1] : vector<[4]xi32> into vector<4x[4]xi32>
// Same for result rows 2 and 3
-// CHECK-NEXT: %[[T49:[0-9]+]] = llvm.intr.vector.insert %[[T37]], %[[T2]][0] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T50:[0-9]+]] = llvm.intr.vector.insert %[[T38]], %[[T49]][4] : vector<[4]xi32> into vector<[8]xi32>
-// CHECK-NEXT: %[[T51:[0-9]+]] = llvm.bitcast %[[T50]] : vector<[8]xi32> to vector<[4]xi64>
-// CHECK-NEXT: %[[T52:[0-9]+]] = "llvm.intr.vector.deinterleave2"(%[[T51]]) : (vector<[4]xi64>) -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T53:[0-9]+]] = llvm.extractvalue %[[T52]][0] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T54:[0-9]+]] = llvm.extractvalue %[[T52]][1] : !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
-// CHECK-NEXT: %[[T55:[0-9]+]] = llvm.bitcast %[[T53]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T56:[0-9]+]] = llvm.bitcast %[[T54]] : vector<[2]xi64> to vector<[4]xi32>
-// CHECK-NEXT: %[[T57:[0-9]+]] = llvm.insertvalue %[[T55]], %[[T48]][2] : !llvm.array<4 x vector<[4]xi32>>
-// CHECK-NEXT: %[[T58:[0-9]+]] = llvm.insertvalue %[[T56]], %[[T57]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK-NEXT: %[[T43:[0-9]+]] = vector.scalable.insert %[[PACK_RES_10]], %[[P0]][0] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T44:[0-9]+]] = vector.scalable.insert %[[PACK_RES_11]], %[[T43]][4] : vector<[4]xi32> into vector<[8]xi32>
+// CHECK-NEXT: %[[T45:[0-9]+]] = vector.bitcast %[[T44]] : vector<[8]xi32> to vector<[4]xi64>
+// CHECK-NEXT: %res1_0, %res2_1 = vector.deinterleave %[[T45]] : vector<[4]xi64> -> vector<[2]xi64>
+// CHECK-NEXT: %[[UNPACK_RES_2:[0-9]+]] = vector.bitcast %res1_0 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[UNPACK_RES_3:[0-9]+]] = vector.bitcast %res2_1 : vector<[2]xi64> to vector<[4]xi32>
+// CHECK-NEXT: %[[TMP_OUT_2:[0-9]+]] = vector.insert %[[UNPACK_RES_2]], %[[TMP_OUT_1]] [2] : vector<[4]xi32> into vector<4x[4]xi32>
+// CHECK-NEXT: %[[OUT:[0-9]+]] = vector.insert %[[UNPACK_RES_3]], %[[TMP_OUT_2]] [3] : vector<[4]xi32> into vector<4x[4]xi32>
+
+// CHECK-NEXT: return %[[OUT]] : vector<4x[4]xi32>
func.func @test_vector_contract_to_usmmla(
%lhs: vector<4x8xi8>,
@@ -93,3 +99,15 @@ func.func @test_vector_contract_to_usmmla(
return %2 : vector<4x[4]xi32>
}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.arm_sve.lower_contraction
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+}
\ No newline at end of file
More information about the llvm-branch-commits
mailing list