[Mlir-commits] [mlir] andrzej/vector/add mmt4d with sve e2e (PR #157815)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Sep 10 02:09:16 PDT 2025


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/157815

- [mlir][vector] Add a new TD op to wrap unit-dim collapsing patterns
- Add missing LIT excludes
- Remove TestVectorTransferCollapseInnerMostContiguousDims
- [mlir][test] Add e2e test for linalg.mmt4d + SVE


>From e7e9d24c1befa42d3cd08b1ee8472f3c0cfe4a2c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 8 Sep 2025 16:03:16 +0000
Subject: [PATCH 1/4] [mlir][vector] Add a new TD op to wrap unit-dim
 collapsing patterns

Adds `apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops` TD
Op that wraps the following Vector patterns:
  * `DropInnerMostUnitDimsTransferRead`
  * `DropInnerMostUnitDimsTransferWrite`

This complements other similar patterns.
---
 .../Vector/TransformOps/VectorTransformOps.td      | 14 ++++++++++++++
 .../Vector/TransformOps/VectorTransformOps.cpp     |  5 +++++
 .../Dialect/Vector/td/xfer-drop-unit-dims.mlir     | 11 +++++++++++
 .../vector-transfer-collapse-inner-most-dims.mlir  |  3 +++
 4 files changed, 33 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 07a4117a37b2c..85d0b2a28c65b 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -85,6 +85,20 @@ def ApplyDropUnitDimWithShapeCastPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Apply vector patterns to drop the inner most unit dims from
+    vector.transfer_read and vector.transfer_write Ops by taking a subview (via
+    memref.subview) of the original source/destination MemRef. Since it
+    requires the input/output to be MemRefs, this Op is only helpful
+    post-bufferization.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.transfer_permutation_patterns",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index fe066dc04ad55..1bad9221df915 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -88,6 +88,11 @@ void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
   vector::populateDropUnitDimWithShapeCastPatterns(patterns);
 }
 
+void transform::ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  vector::populateDropInnerMostUnitDimsXferOpPatterns(patterns);
+}
+
 void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorBitCastLoweringPatterns(patterns);
diff --git a/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir b/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir
new file mode 100644
index 0000000000000..5bffa20842b0c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir
@@ -0,0 +1,11 @@
+module @transforms attributes { transform.with_named_sequence } {
+  transform.named_sequence @drop_unit_dims(%module: !transform.any_op {transform.readonly}) {
+
+    %func_op = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
+    } : !transform.op<"func.func">
+
+    transform.yield
+   }
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index cd56c1bf9695b..52f65215d9dd4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -1,4 +1,7 @@
 // RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt -split-input-file \
+// RUN: -transform-preload-library='transform-library-paths=%p/td/xfer-drop-unit-dims.mlir' \
+// RUN: -transform-interpreter=entry-point=drop_unit_dims %s | FileCheck %s
 
 //-----------------------------------------------------------------------------
 // 1. vector.transfer_read

>From de726b0a14f1c8e03b2b352bf2c216cb08cb62cd Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 9 Sep 2025 08:27:30 +0000
Subject: [PATCH 2/4] Add missing LIT excludes

---
 mlir/test/Dialect/Vector/lit.local.cfg | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/lit.local.cfg

diff --git a/mlir/test/Dialect/Vector/lit.local.cfg b/mlir/test/Dialect/Vector/lit.local.cfg
new file mode 100644
index 0000000000000..62743008a3e3a
--- /dev/null
+++ b/mlir/test/Dialect/Vector/lit.local.cfg
@@ -0,0 +1,2 @@
+# Skip the directory with input TD sequences
+config.excludes = ["td"]

>From 01f3b48cc337513d24f0504b87fea20f4adf64f5 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 10 Sep 2025 07:20:45 +0000
Subject: [PATCH 3/4] Remove TestVectorTransferCollapseInnerMostContiguousDims

---
 ...tor-transfer-collapse-inner-most-dims.mlir |  1 -
 .../Dialect/Vector/TestVectorTransforms.cpp   | 32 -------------------
 2 files changed, 33 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 52f65215d9dd4..18c28799a62e5 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -1,4 +1,3 @@
-// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
 // RUN: mlir-opt -split-input-file \
 // RUN: -transform-preload-library='transform-library-paths=%p/td/xfer-drop-unit-dims.mlir' \
 // RUN: -transform-interpreter=entry-point=drop_unit_dims %s | FileCheck %s
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..20ca6b37ece4d 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -344,36 +344,6 @@ struct TestVectorTransferOpt
   }
 };
 
-struct TestVectorTransferCollapseInnerMostContiguousDims
-    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
-                         OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      TestVectorTransferCollapseInnerMostContiguousDims)
-
-  TestVectorTransferCollapseInnerMostContiguousDims() = default;
-  TestVectorTransferCollapseInnerMostContiguousDims(
-      const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<memref::MemRefDialect, affine::AffineDialect>();
-  }
-
-  StringRef getArgument() const final {
-    return "test-vector-transfer-collapse-inner-most-dims";
-  }
-
-  StringRef getDescription() const final {
-    return "Test lowering patterns that reduces the rank of the vector "
-           "transfer memory and vector operands.";
-  }
-
-  void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateDropInnerMostUnitDimsXferOpPatterns(patterns);
-    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
-  }
-};
-
 struct TestVectorSinkPatterns
     : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
@@ -1057,8 +1027,6 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorTransferOpt>();
 
-  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
-
   PassRegistration<TestVectorSinkPatterns>();
 
   PassRegistration<TestVectorReduceToContractPatternsPatterns>();

>From eb702dba597227b843ace304530063a56936caee Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 10 Sep 2025 09:03:56 +0000
Subject: [PATCH 4/4] [mlir][test] Add e2e test for linalg.mmt4d + SVE

Adds an end-to-end test for computing matrix-multiplication using
linalg.mmt4d, combined with "scalable" tiling and "scalable"
vectorisation. This is similar to an existing example that does not use
"scalable" sizes:
* test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
---
 .../Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir  | 398 ++++++++++++++++++
 1 file changed, 398 insertions(+)
 create mode 100644 mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir
new file mode 100644
index 0000000000000..d001353ef1d7e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-unpack-mmt4d.mlir
@@ -0,0 +1,398 @@
+// DEFINE: %{compile} =  mlir-opt %s \
+// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE:    -cse -canonicalize -test-lower-to-llvm
+// DEFINE: %{entry_point} = main
+// DEFINE: %{run} = mlir-runner -e %{entry_point} -entry-point-result=void \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+/// HIGH-LEVEL OVERVIEW
+///
+/// End-to-end test for computing matrix-multiplication using linalg.mmt4d. In
+/// particular, demonstrates how the following MLIR sequence (implemented in
+/// @matmul_via_mmt4d):
+///
+///   A_pack = linalg.pack A
+///   B_pack = linalg.pack B
+///   C_pack = linalg.pack C
+///   out_pack = linalg.mmt4d(A_pack, B_pack, C_pack)
+///
+/// is equivalent to:
+///
+///  linalg.matmul(A, B, C)
+///
+/// (implemented in @matmul).
+///
+/// NOTES ON IMPLEMENTATION
+/// 1. The MMT4D example uses _scalable_ tile sizes for data tiling.
+///   * The matrix-multiplication dimension that's scalable: N.
+///
+/// 2. The lowering of linalg.mmt4d leverages scalable vectorisation.
+///   * The matrix-multiplication dimension that's scalable: N (to match data
+///     tiling configuration).
+///
+/// 3. Neither `linalg.pack` nor `linalg.unpack` are vectorised ATM.
+///
+/// 4. The MMT4D and Pack/Unpack Ops are kept in separate functions to isolate
+///    the corresponding lowering and lowering configs.
+///   * TODO: Ideally, we should consider fusion opportunities by moving these
+///     Ops into one function.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// @main
+//
+// The main entry point that computes matrix multiplication via linalg.mmt4d
+// and linalg.matmul. Note, the output should be independent of the underlying
+// Linalg Op used, as well as SVE vector length.
+//===----------------------------------------------------------------------===//
+func.func @main() {
+  // Allocate and initialise the inputs
+  %A_empty = tensor.empty() : tensor<7x16xi32>
+  %B_empty = tensor.empty() : tensor<16x13xi32>
+
+  %c3 = arith.constant 3 : i32
+  %c4 = arith.constant 4 : i32
+  %A = linalg.fill ins(%c3 : i32) outs(%A_empty : tensor<7x16xi32>) -> tensor<7x16xi32>
+  %B = linalg.fill ins(%c4 : i32) outs(%B_empty : tensor<16x13xi32>) -> tensor<16x13xi32>
+  %C = arith.constant dense<[
+    [ 1,  8, 15, 22, 29, 36, 43, 50, 57, 64, 71, 78, 85],
+    [ 2,  9, 16, 23, 30, 37, 44, 51, 58, 65, 72, 79, 86],
+    [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66, 73, 80, 87],
+    [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67, 74, 81, 88],
+    [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89],
+    [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69, 76, 83, 90],
+    [ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91]
+  ]> : tensor<7x13xi32>
+
+  // VARIANT: Matrix multiplication via linalg.mmt4d
+  // CHECK: Unranked Memref
+  // CHECK:  [193,   200,   207,   214,   221,   228,   235,   242,   249,   256,   263,   270,   277]
+  // CHECK:  [194,   201,   208,   215,   222,   229,   236,   243,   250,   257,   264,   271,   278]
+  // CHECK:  [195,   202,   209,   216,   223,   230,   237,   244,   251,   258,   265,   272,   279]
+  // CHECK:  [196,   203,   210,   217,   224,   231,   238,   245,   252,   259,   266,   273,   280]
+  // CHECK:  [197,   204,   211,   218,   225,   232,   239,   246,   253,   260,   267,   274,   281]
+  // CHECK:  [198,   205,   212,   219,   226,   233,   240,   247,   254,   261,   268,   275,   282]
+  // CHECK:  [199,   206,   213,   220,   227,   234,   241,   248,   255,   262,   269,   276,   283]
+  %C_mmt4d = func.call @matmul_via_mmt4d(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+  %C_mmt4d_cast = tensor.cast %C_mmt4d : tensor<7x13xi32> to tensor<*xi32>
+  vector.print str "--------------------------\n"
+  vector.print str "RESULT FROM linalg.mmt4d:\n"
+  vector.print str "--------------------------\n"
+  call @printMemrefI32(%C_mmt4d_cast) : (tensor<*xi32>) -> ()
+
+  // VARIANT: Matrix multiplication via linalg.matmul
+  // CHECK: Unranked Memref
+  // CHECK:  [193,   200,   207,   214,   221,   228,   235,   242,   249,   256,   263,   270,   277]
+  // CHECK:  [194,   201,   208,   215,   222,   229,   236,   243,   250,   257,   264,   271,   278]
+  // CHECK:  [195,   202,   209,   216,   223,   230,   237,   244,   251,   258,   265,   272,   279]
+  // CHECK:  [196,   203,   210,   217,   224,   231,   238,   245,   252,   259,   266,   273,   280]
+  // CHECK:  [197,   204,   211,   218,   225,   232,   239,   246,   253,   260,   267,   274,   281]
+  // CHECK:  [198,   205,   212,   219,   226,   233,   240,   247,   254,   261,   268,   275,   282]
+  // CHECK:  [199,   206,   213,   220,   227,   234,   241,   248,   255,   262,   269,   276,   283]
+  %C_matmul = func.call @matmul(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+  %C_matmul_cast = tensor.cast %C_matmul : tensor<7x13xi32> to tensor<*xi32>
+  vector.print str "\n--------------------------\n"
+  vector.print str "RESULT FROM linalg.matmul:\n"
+  vector.print str "--------------------------\n"
+  call @printMemrefI32(%C_matmul_cast) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// @matmul
+//
+// Implements matrix-multiplication via linalg.matmul
+//===----------------------------------------------------------------------===//
+func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+  %C_matmul = linalg.matmul ins(%A, %B: tensor<7x16xi32>, tensor<16x13xi32>)
+                            outs(%C: tensor<7x13xi32>) -> tensor<7x13xi32>
+
+  return %C_matmul : tensor<7x13xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @pack_lhs
+//
+// Implements packing for the A matrix (LHS) in matrix multiplication.
+//===----------------------------------------------------------------------===//
+func.func private @pack_lhs(%A: tensor<7x16xi32>) -> tensor<1x16x8x1xi32> {
+  %pad = arith.constant 0 : i32
+
+  %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32>
+  %A_pack = linalg.pack %A
+    padding_value(%pad : i32)
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 1]
+    into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32>
+
+  return %A_pack : tensor<1x16x8x1xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @pack_rhs
+//
+// Implements packing for the B matrix (RHS) in matrix multiplication. The
+// inner tile size is "scalable": 8 * vscale.
+//===----------------------------------------------------------------------===//
+func.func private @pack_rhs(%B: tensor<16x13xi32>) ->  tensor<?x16x?x1xi32> {
+  %pad = arith.constant 0 : i32
+
+  // Compute the outer tile size.
+  %vs = vector.vscale
+  %c8 = arith.constant 8 : index
+  %vs_c8 = arith.muli %vs, %c8 : index
+  %c13 = arith.constant 13 : index
+  %outer_tile_size = arith.ceildivui %c13, %vs_c8 : index
+
+  %B_pack_empty = tensor.empty(%outer_tile_size, %vs_c8) : tensor<?x16x?x1xi32>
+  %B_pack = linalg.pack %B
+     padding_value(%pad : i32)
+     outer_dims_perm = [1, 0]
+     inner_dims_pos = [1, 0]
+     inner_tiles = [%vs_c8, 1]
+     into %B_pack_empty : tensor<16x13xi32> -> tensor<?x16x?x1xi32>
+
+  return %B_pack : tensor<?x16x?x1xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @pack_acc
+//
+// Implements packing for the C matrix (accumulator) in matrix multiplication.
+// The inner tile size is "scalable": 8 * vscale
+//===----------------------------------------------------------------------===//
+func.func private @pack_acc(%C: tensor<7x13xi32>) -> tensor<1x?x8x?xi32> {
+  %pad = arith.constant 0 : i32
+
+  // Compute the outer tile size.
+  %c13 = arith.constant 13 : index
+  %vs = vector.vscale
+  %c8 = arith.constant 8 : index
+  %vs_c8 = arith.muli %vs, %c8 : index
+  %outer_tile_size = arith.ceildivui %c13, %vs_c8 : index
+
+  %C_pack_empty = tensor.empty(%outer_tile_size, %vs_c8) : tensor<1x?x8x?xi32>
+  %C_pack = linalg.pack %C
+    padding_value(%pad : i32)
+    outer_dims_perm = [0, 1]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, %vs_c8] into %C_pack_empty : tensor<7x13xi32> -> tensor<1x?x8x?xi32>
+
+  return %C_pack : tensor<1x?x8x?xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// @unpack_acc
+//
+// Implements unpacking for the C matrix (accumulator) in matrix
+// multiplication. The inner tile size is "scalable": 8 * vscale
+//===----------------------------------------------------------------------===//
+func.func private @unpack_acc(%C_packed: tensor<1x?x8x?xi32>) -> tensor<7x13xi32> {
+  %vs = vector.vscale
+  %c8 = arith.constant 8 : index
+  %vs_c8 = arith.muli %vs, %c8 : index
+
+  %C_out_empty = tensor.empty() : tensor<7x13xi32>
+  %C_out_unpack = linalg.unpack %C_packed
+    outer_dims_perm = [0, 1]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, %vs_c8]
+    into %C_out_empty : tensor<1x?x8x?xi32> -> tensor<7x13xi32>
+
+  return %C_out_unpack: tensor<7x13xi32>
+}
+
+//===----------------------------------------------------------------------===//
+//  Helper methods for printing
+//===----------------------------------------------------------------------===//
+func.func private @print_pack_A(%A_pack : tensor<1x16x8x1xi32>) -> () {
+  %A_pack_cast = tensor.cast %A_pack : tensor<1x16x8x1xi32> to tensor<*xi32>
+  call @printMemrefI32(%A_pack_cast) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+func.func private @print_pack_B(%B_pack : tensor<?x16x?x1xi32>) -> () {
+  %B_pack_cast = tensor.cast %B_pack : tensor<?x16x?x1xi32> to tensor<*xi32>
+  call @printMemrefI32(%B_pack_cast) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+func.func private @print_pack_C(%C_pack : tensor<1x?x8x?xi32>) -> () {
+  %C_pack_cast = tensor.cast %C_pack : tensor<1x?x8x?xi32> to tensor<*xi32>
+  call @printMemrefI32(%C_pack_cast) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// @matmul_via_mmt4d
+//
+// Implements matrix-multiplication via linalg.mmt4d
+//===----------------------------------------------------------------------===//
+func.func private @matmul_via_mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+  // Pack input matrices
+  %A_pack = func.call @pack_lhs(%A): (tensor<7x16xi32>) -> tensor<1x16x8x1xi32>
+  %B_pack = func.call @pack_rhs(%B): (tensor<16x13xi32>) -> tensor<?x16x?x1xi32>
+  %C_pack = func.call @pack_acc(%C): (tensor<7x13xi32>) -> tensor<1x?x8x?xi32>
+
+  // Print the packed matrices (this is the only _visible_ part that changes
+  // when adjusting the SVE vector size).
+  func.call @print_pack_A(%A_pack) : (tensor<1x16x8x1xi32>) -> ()
+  func.call @print_pack_B(%B_pack) : (tensor<?x16x?x1xi32>) -> ()
+  func.call @print_pack_C(%C_pack) : (tensor<1x?x8x?xi32>) -> ()
+
+  // MMT4D
+  %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<?x16x?x1xi32>) outs(%C_pack : tensor<1x?x8x?xi32>) -> tensor<1x?x8x?xi32>
+
+  // Unpack the output
+  %C_out_unpack = func.call @unpack_acc(%mmt4d) : (tensor<1x?x8x?xi32>) -> tensor<7x13xi32>
+
+  return %C_out_unpack : tensor<7x13xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// TD Sequence
+//===----------------------------------------------------------------------===//
+module @transforms attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.consumed}) {
+    //==========================================================================
+    // HANDLE MMT4D
+    //==========================================================================
+    %mmt4d = transform.collect_matching @match_mmt4d in %module : (!transform.any_op) -> (!transform.any_op)
+    %mmt4d_func = transform.get_parent_op %mmt4d {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+
+    // Step 1: Tile
+    // Tile parallel dims (note, the N dim is scalable!)
+    %tiled_mmt4d_parallel, %_:4 = transform.structured.tile_using_for %mmt4d tile_sizes [1, 1, 0, 8, [8], 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    // Tile reduction dims
+    %tiled_mmt4d, %_1:2 = transform.structured.tile_using_for %tiled_mmt4d_parallel tile_sizes [0, 0, 1, 0, 0, 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize linalg.mmt4d (note, the N dim is scalable!)
+    transform.structured.vectorize %tiled_mmt4d
+      vector_sizes  [1, 1, 1, 8, [8], 1]  {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
+
+    // Step 3: Simplify
+    // vector.multi_reduction --> vector.contract
+    // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
+    // and with the following split into parallel and reduction dims:
+    //    * parallel, parallel, reduction, parallel, parallel, reduction
+    transform.apply_patterns to %mmt4d_func {
+      transform.apply_patterns.vector.reduction_to_contract
+      // Reduce the rank of xfer ops. This transforms vector.contract to be
+      // more matmul-like and to enable the lowering to outer product Ops.
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.op<"func.func">
+
+   // Hoisting and LICM - not strictly required
+   %mmt4d_func_h = transform.structured.hoist_redundant_vector_transfers %mmt4d_func
+     : (!transform.op<"func.func">) -> !transform.op<"func.func">
+   %all_loops = transform.structured.match interface{LoopLikeInterface} in %mmt4d_func_h
+     : (!transform.op<"func.func">) -> !transform.any_op
+   transform.apply_licm to %all_loops : !transform.any_op
+   transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+   // Simplification
+   transform.apply_patterns to %mmt4d_func_h {
+     transform.apply_patterns.vector.reduction_to_contract
+     transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+     transform.apply_patterns.canonicalization
+   } : !transform.op<"func.func">
+
+   //==========================================================================
+   // HANDLE PACK + UNPACK
+   //==========================================================================
+    %pack = transform.structured.match ops{["linalg.pack"]} in %module : (!transform.any_op) -> !transform.any_op
+    %unpack = transform.structured.match ops{["linalg.unpack"]} in %module : (!transform.any_op) -> !transform.any_op
+
+    // 1.1 Tile the linalg.pack Op so that we can decompose it into e.g. tensor.pad
+    //    and other lower-level Ops (see step 2.1)
+    %tiled_pack_op_p, %loops_pack:2 = transform.structured.tile_using_for %pack tile_sizes [1, 1]
+       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // 1.2 Tile the linalg.unpack Op so that we can decompose it into e.g. tensor.pad
+    //    and other lower-level Ops (see step 2.2)
+    %tiled_unpack_op_p, %loops_unpack:2 = transform.structured.tile_using_for %unpack tile_sizes [8, 1]
+       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // 2.1. Decompose tiled PackOp into lower-level Ops + simplify
+    %func_op_pack = transform.get_parent_op %tiled_pack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op_pack {
+      transform.apply_patterns.linalg.decompose_pack_unpack
+      transform.apply_patterns.linalg.decompose_pad
+    } : !transform.op<"func.func">
+
+    transform.apply_patterns to %func_op_pack {
+      transform.apply_patterns.tensor.fold_tensor_subset_ops
+      transform.apply_patterns.canonicalization
+    } : !transform.op<"func.func">
+
+    // 2.2. Decompose tiled UnpackOp into lower-level Ops + simplify
+    %func_op_unpack = transform.get_parent_op %tiled_unpack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op_unpack {
+      transform.apply_patterns.linalg.decompose_pack_unpack
+    } : !transform.op<"func.func">
+
+    transform.apply_patterns to %func_op_unpack {
+      transform.apply_patterns.tensor.fold_tensor_subset_ops
+      transform.apply_patterns.canonicalization
+    } : !transform.op<"func.func">
+
+   //==========================================================================
+   // BUFFERIZATION
+   //==========================================================================
+   %bufferize = transform.bufferization.one_shot_bufferize %module
+     {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
+
+   //==========================================================================
+   // SIMPLIFY THE CONTRACT Op
+   //==========================================================================
+   %contract = transform.collect_matching @match_contract in %bufferize : (!transform.any_op) -> (!transform.any_op)
+   %contract_func = transform.get_parent_op %contract {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
+
+   // Drop trailing unit dims (the corresponding pattern works only
+   // post-bufferization)
+   transform.apply_patterns to %contract_func {
+      transform.apply_patterns.tensor.fold_tensor_subset_ops
+      transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops
+      transform.apply_patterns.canonicalization
+    } : !transform.op<"func.func">
+
+   //==========================================================================
+   // LOWER CONTRACT TO FMA
+   //==========================================================================
+   transform.apply_patterns to %contract_func {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_outerproduct
+   } : !transform.op<"func.func">
+
+   transform.yield
+   }
+
+   //==========================================================================
+   // TD MATCHERS (helper hooks)
+   //==========================================================================
+  transform.named_sequence @match_mmt4d(
+      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.match.operation_name %entry ["linalg.mmt4d"] : !transform.any_op
+    transform.yield %entry : !transform.any_op
+  }
+
+  transform.named_sequence @match_contract(
+      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.match.operation_name %entry ["vector.contract"] : !transform.any_op
+    transform.yield %entry : !transform.any_op
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Function signatures
+//===----------------------------------------------------------------------===//
+func.func private @printMemrefI32(%ptr : tensor<*xi32>)



More information about the Mlir-commits mailing list