[Mlir-commits] [mlir] andrzej/vector/add mmt4d with sve e2e (PR #157815)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Sep 10 02:09:16 PDT 2025
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/157815
- [mlir][vector] Add a new TD op to wrap unit-dim collapsing patterns
- Add missing LIT excludes
- Remove TestVectorTransferCollapseInnerMostContiguousDims
- [mlir][test] Add e2e test for linalg.mmt4d + SVE
From e7e9d24c1befa42d3cd08b1ee8472f3c0cfe4a2c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 8 Sep 2025 16:03:16 +0000
Subject: [PATCH 1/4] [mlir][vector] Add a new TD op to wrap unit-dim
collapsing patterns
Adds the `apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops` TD
Op, which wraps the following Vector patterns:
* `DropInnerMostUnitDimsTransferRead`
* `DropInnerMostUnitDimsTransferWrite`
This complements the existing `apply_patterns.vector.*` wrapper Ops.
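
A minimal usage sketch, mirroring the test added in this patch (names
shortened):

  transform.named_sequence @drop_unit_dims(%module: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %module
      : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %f {
      transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
    } : !transform.op<"func.func">
    transform.yield
  }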
---
.../Vector/TransformOps/VectorTransformOps.td | 14 ++++++++++++++
.../Vector/TransformOps/VectorTransformOps.cpp | 5 +++++
.../Dialect/Vector/td/xfer-drop-unit-dims.mlir | 11 +++++++++++
.../vector-transfer-collapse-inner-most-dims.mlir | 3 +++
4 files changed, 33 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 07a4117a37b2c..85d0b2a28c65b 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -85,6 +85,20 @@ def ApplyDropUnitDimWithShapeCastPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+    Apply vector patterns to drop the innermost unit dims from
+    vector.transfer_read and vector.transfer_write Ops by taking a subview
+    (via memref.subview) of the original source/destination MemRef. Since it
+    requires the input/output to be MemRefs, this Op is only helpful
+    post-bufferization.
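+
+    For example (a sketch), a transfer such as:
+
+      %v = vector.transfer_read %src[%c0, %c0], %pad : memref<8x1xf32>, vector<4x1xf32>
+
+    would be rewritten to read a vector<4xf32> from a rank-reduced
+    memref.subview of %src, followed by a vector.shape_cast that restores
+    the original vector<4x1xf32> type.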
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.transfer_permutation_patterns",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index fe066dc04ad55..1bad9221df915 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -88,6 +88,11 @@ void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
vector::populateDropUnitDimWithShapeCastPatterns(patterns);
}
+void transform::ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ vector::populateDropInnerMostUnitDimsXferOpPatterns(patterns);
+}
+
void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorBitCastLoweringPatterns(patterns);
diff --git a/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir b/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir
new file mode 100644
index 0000000000000..5bffa20842b0c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir
@@ -0,0 +1,11 @@
+module @transforms attributes { transform.with_named_sequence } {
+ transform.named_sequence @drop_unit_dims(%module: !transform.any_op {transform.readonly}) {
+
+ %func_op = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index cd56c1bf9695b..52f65215d9dd4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -1,4 +1,7 @@
// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt -split-input-file \
+// RUN: -transform-preload-library='transform-library-paths=%p/td/xfer-drop-unit-dims.mlir' \
+// RUN: -transform-interpreter=entry-point=drop_unit_dims %s | FileCheck %s
//-----------------------------------------------------------------------------
// 1. vector.transfer_read
From de726b0a14f1c8e03b2b352bf2c216cb08cb62cd Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 9 Sep 2025 08:27:30 +0000
Subject: [PATCH 2/4] Add missing LIT excludes
---
mlir/test/Dialect/Vector/lit.local.cfg | 2 ++
1 file changed, 2 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/lit.local.cfg
diff --git a/mlir/test/Dialect/Vector/lit.local.cfg b/mlir/test/Dialect/Vector/lit.local.cfg
new file mode 100644
index 0000000000000..62743008a3e3a
--- /dev/null
+++ b/mlir/test/Dialect/Vector/lit.local.cfg
@@ -0,0 +1,2 @@
+# Skip the directory with input TD sequences
+config.excludes = ["td"]
From 01f3b48cc337513d24f0504b87fea20f4adf64f5 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 10 Sep 2025 07:20:45 +0000
Subject: [PATCH 3/4] Remove TestVectorTransferCollapseInnerMostContiguousDims
---
...tor-transfer-collapse-inner-most-dims.mlir | 1 -
.../Dialect/Vector/TestVectorTransforms.cpp | 32 -------------------
2 files changed, 33 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 52f65215d9dd4..18c28799a62e5 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -1,4 +1,3 @@
-// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
// RUN: mlir-opt -split-input-file \
// RUN: -transform-preload-library='transform-library-paths=%p/td/xfer-drop-unit-dims.mlir' \
// RUN: -transform-interpreter=entry-point=drop_unit_dims %s | FileCheck %s
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..20ca6b37ece4d 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -344,36 +344,6 @@ struct TestVectorTransferOpt
}
};
-struct TestVectorTransferCollapseInnerMostContiguousDims
- : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
- OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- TestVectorTransferCollapseInnerMostContiguousDims)
-
- TestVectorTransferCollapseInnerMostContiguousDims() = default;
- TestVectorTransferCollapseInnerMostContiguousDims(
- const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
-
-  void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<memref::MemRefDialect, affine::AffineDialect>();
- }
-
- StringRef getArgument() const final {
- return "test-vector-transfer-collapse-inner-most-dims";
- }
-
- StringRef getDescription() const final {
- return "Test lowering patterns that reduces the rank of the vector "
- "transfer memory and vector operands.";
- }
-
- void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- populateDropInnerMostUnitDimsXferOpPatterns(patterns);
- (void)applyPatternsGreedily(getOperation(), std::move(patterns));
- }
-};
-
struct TestVectorSinkPatterns
: public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
@@ -1057,8 +1027,6 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorTransferOpt>();
- PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
-
PassRegistration<TestVectorSinkPatterns>();
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
From eb702dba597227b843ace304530063a56936caee Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 10 Sep 2025 09:03:56 +0000
Subject: [PATCH 4/4] [mlir][test] Add e2e test for linalg.mmt4d + SVE
Adds an end-to-end test for computing matrix-multiplication using
linalg.mmt4d, combined with "scalable" tiling and "scalable"
vectorisation. This is similar to an existing example that does not use
"scalable" sizes:
* test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
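
The scalable inner tile size for the N dimension is computed at runtime
from vector.vscale. In essence, the packing helpers do the following
(a sketch; %c13 is the static N extent, other names are abbreviated):

  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %tile_n = arith.muli %vs, %c8 : index           // inner tile: 8 * vscale
  %outer = arith.ceildivui %c13, %tile_n : index  // number of outer tiles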
---
.../Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir | 398 ++++++++++++++++++
1 file changed, 398 insertions(+)
create mode 100644 mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir
new file mode 100644
index 0000000000000..d001353ef1d7e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir
@@ -0,0 +1,398 @@
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE: -cse -canonicalize -test-lower-to-llvm
+// DEFINE: %{entry_point} = main
+// DEFINE: %{run} = mlir-runner -e %{entry_point} -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+/// HIGH-LEVEL OVERVIEW
+///
+/// End-to-end test for computing matrix-multiplication using linalg.mmt4d. In
+/// particular, demonstrates how the following MLIR sequence (implemented in
+/// @matmul_via_mmt4d):
+///
+/// A_pack = linalg.pack A
+/// B_pack = linalg.pack B
+/// C_pack = linalg.pack C
+/// out_pack = linalg.mmt4d(A_pack, B_pack, C_pack)
+///
+/// is equivalent to:
+///
+/// linalg.matmul(A, B, C)
+///
+/// (implemented in @matmul).
+///
+/// NOTES ON IMPLEMENTATION
+/// 1. The MMT4D example uses _scalable_ tile sizes for data tiling.
+/// * The matrix-multiplication dimension that's scalable: N.
+///
+/// 2. The lowering of linalg.mmt4d leverages scalable vectorisation.
+/// * The matrix-multiplication dimension that's scalable: N (to match data
+/// tiling configuration).
+///
+/// 3. Neither `linalg.pack` nor `linalg.unpack` is vectorised ATM.
+///
+/// 4. The MMT4D and Pack/Unpack Ops are kept in separate functions to isolate
+/// the corresponding lowering and lowering configs.
+/// * TODO: Ideally, we should consider fusion opportunities by moving these
+/// Ops into one function.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// @main
+//
+// The main entry point that computes matrix multiplication via linalg.mmt4d
+// and linalg.matmul. Note, the output should be independent of both the
+// underlying Linalg Op used and the SVE vector length.
+//===----------------------------------------------------------------------===//
+func.func @main() {
+ // Allocate and initialise the inputs
+ %A_empty = tensor.empty() : tensor<7x16xi32>
+ %B_empty = tensor.empty() : tensor<16x13xi32>
+
+ %c3 = arith.constant 3 : i32
+ %c4 = arith.constant 4 : i32
+ %A = linalg.fill ins(%c3 : i32) outs(%A_empty : tensor<7x16xi32>) -> tensor<7x16xi32>
+ %B = linalg.fill ins(%c4 : i32) outs(%B_empty : tensor<16x13xi32>) -> tensor<16x13xi32>
+ %C = arith.constant dense<[
+ [ 1, 8, 15, 22, 29, 36, 43, 50, 57, 64, 71, 78, 85],
+ [ 2, 9, 16, 23, 30, 37, 44, 51, 58, 65, 72, 79, 86],
+ [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66, 73, 80, 87],
+ [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67, 74, 81, 88],
+ [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89],
+ [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69, 76, 83, 90],
+ [ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91]
+ ]> : tensor<7x13xi32>
+
+ // VARIANT: Matrix multiplication via linalg.mmt4d
+ // CHECK: Unranked Memref
+ // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277]
+ // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278]
+ // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279]
+ // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280]
+ // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281]
+ // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282]
+ // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283]
+ %C_mmt4d = func.call @matmul_via_mmt4d(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+ %C_mmt4d_cast = tensor.cast %C_mmt4d : tensor<7x13xi32> to tensor<*xi32>
+ vector.print str "--------------------------\n"
+ vector.print str "RESULT FROM linalg.mmt4d:\n"
+ vector.print str "--------------------------\n"
+ call @printMemrefI32(%C_mmt4d_cast) : (tensor<*xi32>) -> ()
+
+ // VARIANT: Matrix multiplication via linalg.matmul
+ // CHECK: Unranked Memref
+ // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277]
+ // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278]
+ // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279]
+ // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280]
+ // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281]
+ // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282]
+ // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283]
+ %C_matmul = func.call @matmul(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+ %C_matmul_cast = tensor.cast %C_matmul : tensor<7x13xi32> to tensor<*xi32>
+ vector.print str "\n--------------------------\n"
+ vector.print str "RESULT FROM linalg.matmul:\n"
+ vector.print str "--------------------------\n"
+ call @printMemrefI32(%C_matmul_cast) : (tensor<*xi32>) -> ()
+
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// @matmul
+//
+// Implements matrix-multiplication via linalg.matmul
+//===----------------------------------------------------------------------===//
+func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+ %C_matmul = linalg.matmul ins(%A, %B: tensor<7x16xi32>, tensor<16x13xi32>)
+ outs(%C: tensor<7x13xi32>) -> tensor<7x13xi32>
+
+ return %C_matmul : tensor<7x13xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @pack_lhs
+//
+// Implements packing for the A matrix (LHS) in matrix multiplication. The
+// inner tile sizes are static: [8, 1].
+//===----------------------------------------------------------------------===//
+func.func private @pack_lhs(%A: tensor<7x16xi32>) -> tensor<1x16x8x1xi32> {
+ %pad = arith.constant 0 : i32
+
+ %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32>
+ %A_pack = linalg.pack %A
+ padding_value(%pad : i32)
+ inner_dims_pos = [0, 1]
+ inner_tiles = [8, 1]
+ into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32>
+
+ return %A_pack : tensor<1x16x8x1xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @pack_rhs
+//
+// Implements packing for the B matrix (RHS) in matrix multiplication. The
+// inner tile size is "scalable": 8 * vscale.
+//===----------------------------------------------------------------------===//
+func.func private @pack_rhs(%B: tensor<16x13xi32>) -> tensor<?x16x?x1xi32> {
+ %pad = arith.constant 0 : i32
+
+ // Compute the outer tile size.
+ %vs = vector.vscale
+ %c8 = arith.constant 8 : index
+ %vs_c8 = arith.muli %vs, %c8 : index
+ %c13 = arith.constant 13 : index
+ %outer_tile_size = arith.ceildivui %c13, %vs_c8 : index
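+  // E.g. with vscale == 1 the inner tile is 8, giving ceildiv(13, 8) == 2
+  // outer tiles; with vscale == 2 it is 16, giving a single outer tile.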
+
+ %B_pack_empty = tensor.empty(%outer_tile_size, %vs_c8) : tensor<?x16x?x1xi32>
+ %B_pack = linalg.pack %B
+ padding_value(%pad : i32)
+ outer_dims_perm = [1, 0]
+ inner_dims_pos = [1, 0]
+ inner_tiles = [%vs_c8, 1]
+ into %B_pack_empty : tensor<16x13xi32> -> tensor<?x16x?x1xi32>
+
+ return %B_pack : tensor<?x16x?x1xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @pack_acc
+//
+// Implements packing for the C matrix (accumulator) in matrix multiplication.
+// The inner tile size is "scalable": 8 * vscale.
+//===----------------------------------------------------------------------===//
+func.func private @pack_acc(%C: tensor<7x13xi32>) -> tensor<1x?x8x?xi32> {
+ %pad = arith.constant 0 : i32
+
+ // Compute the outer tile size.
+ %c13 = arith.constant 13 : index
+ %vs = vector.vscale
+ %c8 = arith.constant 8 : index
+ %vs_c8 = arith.muli %vs, %c8 : index
+ %outer_tile_size = arith.ceildivui %c13, %vs_c8 : index
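+  // Same computation as in @pack_rhs: the N dim (13) is split into
+  // ceildiv(13, 8 * vscale) outer tiles of scalable width.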
+
+ %C_pack_empty = tensor.empty(%outer_tile_size, %vs_c8) : tensor<1x?x8x?xi32>
+ %C_pack = linalg.pack %C
+ padding_value(%pad : i32)
+ outer_dims_perm = [0, 1]
+ inner_dims_pos = [0, 1]
+ inner_tiles = [8, %vs_c8] into %C_pack_empty : tensor<7x13xi32> -> tensor<1x?x8x?xi32>
+
+ return %C_pack : tensor<1x?x8x?xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @unpack_acc
+//
+// Implements unpacking for the C matrix (accumulator) in matrix
+// multiplication. The inner tile size is "scalable": 8 * vscale.
+//===----------------------------------------------------------------------===//
+func.func private @unpack_acc(%C_packed: tensor<1x?x8x?xi32>) -> tensor<7x13xi32> {
+ %vs = vector.vscale
+ %c8 = arith.constant 8 : index
+ %vs_c8 = arith.muli %vs, %c8 : index
+
+ %C_out_empty = tensor.empty() : tensor<7x13xi32>
+ %C_out_unpack = linalg.unpack %C_packed
+ outer_dims_perm = [0, 1]
+ inner_dims_pos = [0, 1]
+ inner_tiles = [8, %vs_c8]
+ into %C_out_empty : tensor<1x?x8x?xi32> -> tensor<7x13xi32>
+
+ return %C_out_unpack: tensor<7x13xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for printing
+//===----------------------------------------------------------------------===//
+func.func private @print_pack_A(%A_pack : tensor<1x16x8x1xi32>) -> () {
+ %A_pack_cast = tensor.cast %A_pack : tensor<1x16x8x1xi32> to tensor<*xi32>
+ call @printMemrefI32(%A_pack_cast) : (tensor<*xi32>) -> ()
+
+ return
+}
+
+func.func private @print_pack_B(%B_pack : tensor<?x16x?x1xi32>) -> () {
+ %B_pack_cast = tensor.cast %B_pack : tensor<?x16x?x1xi32> to tensor<*xi32>
+ call @printMemrefI32(%B_pack_cast) : (tensor<*xi32>) -> ()
+
+ return
+}
+
+func.func private @print_pack_C(%C_pack : tensor<1x?x8x?xi32>) -> () {
+ %C_pack_cast = tensor.cast %C_pack : tensor<1x?x8x?xi32> to tensor<*xi32>
+ call @printMemrefI32(%C_pack_cast) : (tensor<*xi32>) -> ()
+
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// @matmul_via_mmt4d
+//
+// Implements matrix-multiplication via linalg.mmt4d
+//===----------------------------------------------------------------------===//
+func.func private @matmul_via_mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+ // Pack input matrices
+ %A_pack = func.call @pack_lhs(%A): (tensor<7x16xi32>) -> tensor<1x16x8x1xi32>
+ %B_pack = func.call @pack_rhs(%B): (tensor<16x13xi32>) -> tensor<?x16x?x1xi32>
+ %C_pack = func.call @pack_acc(%C): (tensor<7x13xi32>) -> tensor<1x?x8x?xi32>
+
+ // Print the packed matrices (this is the only _visible_ part that changes
+ // when adjusting the SVE vector size).
+ func.call @print_pack_A(%A_pack) : (tensor<1x16x8x1xi32>) -> ()
+ func.call @print_pack_B(%B_pack) : (tensor<?x16x?x1xi32>) -> ()
+ func.call @print_pack_C(%C_pack) : (tensor<1x?x8x?xi32>) -> ()
+
+ // MMT4D
+ %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<?x16x?x1xi32>) outs(%C_pack : tensor<1x?x8x?xi32>) -> tensor<1x?x8x?xi32>
+
+ // Unpack the output
+ %C_out_unpack = func.call @unpack_acc(%mmt4d) : (tensor<1x?x8x?xi32>) -> tensor<7x13xi32>
+
+ return %C_out_unpack : tensor<7x13xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// TD Sequence
+//===----------------------------------------------------------------------===//
+module @transforms attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.consumed}) {
+ //==========================================================================
+ // HANDLE MMT4D
+ //==========================================================================
+ %mmt4d = transform.collect_matching @match_mmt4d in %module : (!transform.any_op) -> (!transform.any_op)
+ %mmt4d_func = transform.get_parent_op %mmt4d {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+
+ // Step 1: Tile
+ // Tile parallel dims (note, the N dim is scalable!)
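+    // The six tile sizes map to mmt4d's (M, N, K, m, n, k) iteration dims;
+    // `[8]` marks the inner N tile as scalable, i.e. 8 * vscale.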
+ %tiled_mmt4d_parallel, %_:4 = transform.structured.tile_using_for %mmt4d tile_sizes [1, 1, 0, 8, [8], 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ // Tile reduction dims
+ %tiled_mmt4d, %_1:2 = transform.structured.tile_using_for %tiled_mmt4d_parallel tile_sizes [0, 0, 1, 0, 0, 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+ // Step 2: Vectorize linalg.mmt4d (note, the N dim is scalable!)
+ transform.structured.vectorize %tiled_mmt4d
+ vector_sizes [1, 1, 1, 8, [8], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
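+    // Note: `assume_dynamic_dims_match_vec_sizes` tells the vectoriser it
+    // may assume each dynamic dim matches the corresponding (scalable)
+    // vector size, so no masking is generated for those dims.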
+
+ // Step 3: Simplify
+ // vector.multi_reduction --> vector.contract
+    // Generates a 6-dim vector.contract with dims matching the original MMT4D Op
+ // and with the following split into parallel and reduction dims:
+ // * parallel, parallel, reduction, parallel, parallel, reduction
+ transform.apply_patterns to %mmt4d_func {
+ transform.apply_patterns.vector.reduction_to_contract
+ // Reduce the rank of xfer ops. This transforms vector.contract to be
+ // more matmul-like and to enable the lowering to outer product Ops.
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ } : !transform.op<"func.func">
+
+ // Hoisting and LICM - not strictly required
+ %mmt4d_func_h = transform.structured.hoist_redundant_vector_transfers %mmt4d_func
+ : (!transform.op<"func.func">) -> !transform.op<"func.func">
+ %all_loops = transform.structured.match interface{LoopLikeInterface} in %mmt4d_func_h
+ : (!transform.op<"func.func">) -> !transform.any_op
+ transform.apply_licm to %all_loops : !transform.any_op
+ transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+ // Simplification
+ transform.apply_patterns to %mmt4d_func_h {
+ transform.apply_patterns.vector.reduction_to_contract
+ transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+ transform.apply_patterns.canonicalization
+ } : !transform.op<"func.func">
+
+ //==========================================================================
+ // HANDLE PACK + UNPACK
+ //==========================================================================
+ %pack = transform.structured.match ops{["linalg.pack"]} in %module : (!transform.any_op) -> !transform.any_op
+ %unpack = transform.structured.match ops{["linalg.unpack"]} in %module : (!transform.any_op) -> !transform.any_op
+
+ // 1.1 Tile the linalg.pack Op so that we can decompose it into e.g. tensor.pad
+ // and other lower-level Ops (see step 2.1)
+ %tiled_pack_op_p, %loops_pack:2 = transform.structured.tile_using_for %pack tile_sizes [1, 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+ // 1.2 Tile the linalg.unpack Op so that we can decompose it into e.g. tensor.pad
+    // and other lower-level Ops (see step 2.2)
+ %tiled_unpack_op_p, %loops_unpack:2 = transform.structured.tile_using_for %unpack tile_sizes [8, 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+ // 2.1. Decompose tiled PackOp into lower-level Ops + simplify
+ %func_op_pack = transform.get_parent_op %tiled_pack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op_pack {
+ transform.apply_patterns.linalg.decompose_pack_unpack
+ transform.apply_patterns.linalg.decompose_pad
+ } : !transform.op<"func.func">
+
+ transform.apply_patterns to %func_op_pack {
+ transform.apply_patterns.tensor.fold_tensor_subset_ops
+ transform.apply_patterns.canonicalization
+ } : !transform.op<"func.func">
+
+ // 2.2. Decompose tiled UnpackOp into lower-level Ops + simplify
+ %func_op_unpack = transform.get_parent_op %tiled_unpack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op_unpack {
+ transform.apply_patterns.linalg.decompose_pack_unpack
+ } : !transform.op<"func.func">
+
+ transform.apply_patterns to %func_op_unpack {
+ transform.apply_patterns.tensor.fold_tensor_subset_ops
+ transform.apply_patterns.canonicalization
+ } : !transform.op<"func.func">
+
+ //==========================================================================
+ // BUFFERIZATION
+ //==========================================================================
+ %bufferize = transform.bufferization.one_shot_bufferize %module
+ {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
+
+ //==========================================================================
+ // SIMPLIFY THE CONTRACT Op
+ //==========================================================================
+ %contract = transform.collect_matching @match_contract in %bufferize : (!transform.any_op) -> (!transform.any_op)
+ %contract_func = transform.get_parent_op %contract {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+
+    // Drop trailing unit dims (the corresponding pattern works only
+ // post-bufferization)
+ transform.apply_patterns to %contract_func {
+ transform.apply_patterns.tensor.fold_tensor_subset_ops
+ transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
+ transform.apply_patterns.canonicalization
+ } : !transform.op<"func.func">
+
+ //==========================================================================
+ // LOWER CONTRACT TO FMA
+ //==========================================================================
+ transform.apply_patterns to %contract_func {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ transform.apply_patterns.vector.lower_outerproduct
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+
+ //==========================================================================
+ // TD MATCHERS (helper hooks)
+ //==========================================================================
+ transform.named_sequence @match_mmt4d(
+ %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.operation_name %entry ["linalg.mmt4d"] : !transform.any_op
+ transform.yield %entry : !transform.any_op
+ }
+
+ transform.named_sequence @match_contract(
+ %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.operation_name %entry ["vector.contract"] : !transform.any_op
+ transform.yield %entry : !transform.any_op
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Function signatures
+//===----------------------------------------------------------------------===//
+func.func private @printMemrefI32(%ptr : tensor<*xi32>)