[Mlir-commits] [mlir] [mlir][sve] Add an e2e for linalg.matmul with mixed types (PR #73773)
Andrzej WarzyĆski
llvmlistbot at llvm.org
Wed Nov 29 07:46:32 PST 2023
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/73773
>From b66226c75768f6451d14c0be9fd61a0dc4dbe3ca Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 29 Nov 2023 09:07:54 +0000
Subject: [PATCH] [mlir[[sve] Add an e2e for linalg.matmul with mixed types
Apart from the test itself, this patch also updates a few patterns to
fix how new VectorType(s) are created. Namely, it makes sure that
"scalability" is correctly propagated.
Regression tests will be updated seperately while auditing Vector
dialect tests in the context of scalable vectors:
* https://github.com/orgs/llvm/projects/23
---
.../Vector/Transforms/VectorTransforms.cpp | 14 ++--
.../Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir | 83 +++++++++++++++++++
2 files changed, 89 insertions(+), 8 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 582d627d1ce4ac0..6e7fab293d3a1ca 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast
Type castResTy = getElementTypeOrSelf(op->getResult(0));
if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
- castResTy = VectorType::get(vecTy.getShape(), castResTy);
+ castResTy = vecTy.clone(castResTy);
auto *castOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
bcastOp.getSource(), castResTy, op->getAttrs());
@@ -527,16 +527,14 @@ struct ReorderElementwiseOpsOnTranspose final
srcValues.push_back(transposeOp.getVector());
} else {
// This is a constant. Create a reverse transpose op for it.
- auto vectorType = VectorType::get(
- srcType.getShape(),
- cast<VectorType>(operand.getType()).getElementType());
+ auto vectorType =
+ srcType.clone(cast<VectorType>(operand.getType()).getElementType());
srcValues.push_back(rewriter.create<vector::TransposeOp>(
operand.getLoc(), vectorType, operand, invOrder));
}
}
- auto vectorType = VectorType::get(
- srcType.getShape(),
+ auto vectorType = srcType.clone(
cast<VectorType>(op->getResultTypes()[0]).getElementType());
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
@@ -1314,8 +1312,8 @@ struct CanonicalizeContractMatmulToMMT final
Value trans =
rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
VectorType newType =
- VectorType::get(cast<VectorType>(trans.getType()).getShape(),
- cast<VectorType>(mat.getType()).getElementType());
+ cast<VectorType>(trans.getType())
+ .clone(cast<VectorType>(mat.getType()).getElementType());
return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
}
if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir
new file mode 100644
index 000000000000000..f4f2d87b4d0b42c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir
@@ -0,0 +1,83 @@
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE: -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = matmul_mixed_ty
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
+
+func.func @matmul_mixed_ty() {
+ // Matrix dimensions
+ %K = arith.constant 3 : index
+ %M = arith.constant 5 : index
+ %N = arith.constant 15 : index
+ %c0_i8 = arith.constant 0 : i8
+ %c0_i32 = arith.constant 0 : i32
+
+ // Allocate the matrices
+ %A_alloc = bufferization.alloc_tensor(%M, %K) : tensor<?x?xi8>
+ %B_alloc = bufferization.alloc_tensor(%K, %N) : tensor<?x?xi8>
+ %C_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xi32>
+
+ // Initialise the matrices
+ %pi = arith.constant 123 : i8
+ %A = linalg.fill ins(%pi : i8) outs(%A_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
+ %B = linalg.fill ins(%pi : i8) outs(%B_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
+ %C_in = linalg.fill ins(%c0_i32 : i32) outs(%C_alloc : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+ // Matmul
+ %C_out = linalg.matmul ins(%A, %B: tensor<?x?xi8>, tensor<?x?xi8>) outs(%C_in: tensor<?x?xi32>) -> tensor<?x?xi32>
+
+ // Print and verify the output
+ // CHECK-LABEL: SVE: START OF TEST OUTPUT
+ vector.print str "SVE: START OF TEST OUTPUT"
+
+ // CHECK-NEXT: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [5, 15] strides = [15, 1] data =
+ // CHECK-COUNT-5: [45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387]
+ %xf = tensor.cast %C_out : tensor<?x?xi32> to tensor<*xi32>
+ call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()
+
+ // CHECK-NEXT: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT"
+
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+ : (!transform.any_op) -> !transform.any_op
+
+ // Step 1: Tile
+ %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+
+ // Step 2: Vectorize
+ %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
+
+ // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
+ %func = transform.structured.match ops{["func.func"]} in %module
+ : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.reduction_to_contract
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ transform.apply_patterns.vector.lower_masked_transfers
+ } : !transform.op<"func.func">
+
+ // Step 4: Lower vector.contract to vector.fma
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ transform.apply_patterns.vector.lower_outerproduct
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
+}
+
+func.func private @printMemrefI32(%ptr : tensor<*xi32>)
More information about the Mlir-commits
mailing list