[Mlir-commits] [mlir] [mlir][sve] Add an e2e for linalg.matmul with mixed types (PR #73773)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Nov 29 07:44:06 PST 2023


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/73773

>From ac40fc8e6626507e09e1bbebca116ebcb5623947 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 29 Nov 2023 09:30:35 +0000
Subject: [PATCH 1/2] [mlir][sve][nfc] Update a test to use
 transform-interpreter

This is a follow-up of #70040 in which the test updated here was missed.

Includes a few additional NFC changes in preparation for extending this
test.
---
 .../Dialect/Linalg/CPU/ArmSVE/matmul.mlir     | 70 +++++++++++--------
 1 file changed, 42 insertions(+), 28 deletions(-)

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
index 2024da2a585d99f..d771d32d548bbe2 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
@@ -1,8 +1,14 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule \
-// RUN:   -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
-// RUN:   -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm | \
-// RUN: %mcr_aarch64_cmd -e=matmul_f32 -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
-// RUN: FileCheck %s
+// DEFINE: %{compile} =  mlir-opt %s \
+// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE:    -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE:    -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = matmul_f32
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
 
 func.func @matmul_f32() {
   // Matrix dimensions
@@ -40,29 +46,37 @@ func.func @matmul_f32() {
   return
 }
 
-transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  // Step 1: Tile
-  %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op
-  %func_op = get_parent_op %matmul : (!transform.any_op) -> !transform.op<"func.func">
-  %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-
-  // Step 2: Vectorize
-  %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops : (!transform.any_op) -> !transform.any_op
-  transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
-
-  // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
-  transform.apply_patterns to %func_op {
-    transform.apply_patterns.vector.reduction_to_contract
-    transform.apply_patterns.vector.transfer_permutation_patterns
-    transform.apply_patterns.vector.lower_masked_transfers
-  } : !transform.op<"func.func">
-
-  // Step 4: Lower vector.contract to vector.fma
-  transform.apply_patterns to %func_op {
-    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-    transform.apply_patterns.vector.lower_outerproduct
-  } : !transform.op<"func.func">
+module attributes {transform.with_named_sequence} {
+transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+      : (!transform.any_op) -> !transform.any_op
+
+    // Step 1: Tile
+    %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize
+    %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
+
+    // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
+    %func = transform.structured.match ops{["func.func"]} in %module
+      : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.reduction_to_contract
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.lower_masked_transfers
+    } : !transform.op<"func.func">
+
+    // Step 4: Lower vector.contract to vector.fma
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_outerproduct
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
 }
 
 func.func private @printMemrefF32(%ptr : tensor<*xf32>)

>From 4d199fb97729dc61cef3467316f99e2cce1314e3 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 29 Nov 2023 09:07:54 +0000
Subject: [PATCH 2/2] [mlir][sve] Add an e2e for linalg.matmul with mixed types

Apart from the test itself, this patch also updates a few patterns to
fix how new VectorType(s) are created. Namely, it makes sure that
"scalability" is correctly propagated.

Regression tests will be updated separately while auditing Vector
dialect tests in the context of scalable vectors:
  * https://github.com/orgs/llvm/projects/23
---
 .../Vector/Transforms/VectorTransforms.cpp    | 14 ++--
 .../Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir    | 83 +++++++++++++++++++
 2 files changed, 89 insertions(+), 8 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 582d627d1ce4ac0..6e7fab293d3a1ca 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast
 
     Type castResTy = getElementTypeOrSelf(op->getResult(0));
     if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
-      castResTy = VectorType::get(vecTy.getShape(), castResTy);
+      castResTy = vecTy.clone(castResTy);
     auto *castOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                         bcastOp.getSource(), castResTy, op->getAttrs());
@@ -527,16 +527,14 @@ struct ReorderElementwiseOpsOnTranspose final
         srcValues.push_back(transposeOp.getVector());
       } else {
         // This is a constant. Create a reverse transpose op for it.
-        auto vectorType = VectorType::get(
-            srcType.getShape(),
-            cast<VectorType>(operand.getType()).getElementType());
+        auto vectorType =
+            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
         srcValues.push_back(rewriter.create<vector::TransposeOp>(
             operand.getLoc(), vectorType, operand, invOrder));
       }
     }
 
-    auto vectorType = VectorType::get(
-        srcType.getShape(),
+    auto vectorType = srcType.clone(
         cast<VectorType>(op->getResultTypes()[0]).getElementType());
     Operation *elementwiseOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
@@ -1314,8 +1312,8 @@ struct CanonicalizeContractMatmulToMMT final
         Value trans =
             rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
         VectorType newType =
-            VectorType::get(cast<VectorType>(trans.getType()).getShape(),
-                            cast<VectorType>(mat.getType()).getElementType());
+            cast<VectorType>(trans.getType())
+                .clone(cast<VectorType>(mat.getType()).getElementType());
         return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
       }
       if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir
new file mode 100644
index 000000000000000..f4f2d87b4d0b42c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir
@@ -0,0 +1,83 @@
+// DEFINE: %{compile} =  mlir-opt %s \
+// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE:    -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE:    -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = matmul_mixed_ty
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
+
+func.func @matmul_mixed_ty() {
+  // Matrix dimensions
+  %K = arith.constant 3 : index
+  %M = arith.constant 5 : index
+  %N = arith.constant 15 : index
+  %c0_i8 = arith.constant 0 : i8
+  %c0_i32 = arith.constant 0 : i32
+
+  // Allocate the matrices
+  %A_alloc = bufferization.alloc_tensor(%M, %K) : tensor<?x?xi8>
+  %B_alloc = bufferization.alloc_tensor(%K, %N) : tensor<?x?xi8>
+  %C_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xi32>
+
+  // Initialise the matrices
+  %pi = arith.constant  123 : i8
+  %A = linalg.fill ins(%pi : i8) outs(%A_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
+  %B = linalg.fill ins(%pi : i8) outs(%B_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
+  %C_in = linalg.fill ins(%c0_i32 : i32) outs(%C_alloc : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+  // Matmul
+  %C_out = linalg.matmul ins(%A, %B: tensor<?x?xi8>, tensor<?x?xi8>) outs(%C_in: tensor<?x?xi32>) -> tensor<?x?xi32>
+
+  // Print and verify the output
+  // CHECK-LABEL: SVE: START OF TEST OUTPUT
+  vector.print str "SVE: START OF TEST OUTPUT"
+
+  // CHECK-NEXT: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [5, 15] strides = [15, 1] data =
+  // CHECK-COUNT-5: [45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387,   45387]
+  %xf = tensor.cast %C_out : tensor<?x?xi32> to tensor<*xi32>
+  call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()
+
+  // CHECK-NEXT: SVE: END OF TEST OUTPUT
+  vector.print str "SVE: END OF TEST OUTPUT"
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+      : (!transform.any_op) -> !transform.any_op
+
+    // Step 1: Tile
+    %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize
+    %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
+
+    // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
+    %func = transform.structured.match ops{["func.func"]} in %module
+      : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.reduction_to_contract
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.lower_masked_transfers
+    } : !transform.op<"func.func">
+
+    // Step 4: Lower vector.contract to vector.fma
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_outerproduct
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
+}
+
+func.func private @printMemrefI32(%ptr : tensor<*xi32>)



More information about the Mlir-commits mailing list