[Mlir-commits] [mlir] [mlir][sve] Add an e2e for linalg.matmul with mixed types (PR #73773)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 29 02:00:57 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sve
@llvm/pr-subscribers-mlir-vector
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
Apart from the test itself, this patch also updates a few patterns to
fix how new VectorType(s) are created. Namely, it makes sure that
"scalability" is correctly propagated.
Regression tests will be updated separately while auditing Vector
dialect tests in the context of scalable vectors:
* https://github.com/orgs/llvm/projects/23
---
Full diff: https://github.com/llvm/llvm-project/pull/73773.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+8-8)
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir (+42-28)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 582d627d1ce4ac0..96ec44fcd77677a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast
Type castResTy = getElementTypeOrSelf(op->getResult(0));
if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
- castResTy = VectorType::get(vecTy.getShape(), castResTy);
+ castResTy = vecTy.clone(castResTy);
auto *castOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
bcastOp.getSource(), castResTy, op->getAttrs());
@@ -527,16 +527,14 @@ struct ReorderElementwiseOpsOnTranspose final
srcValues.push_back(transposeOp.getVector());
} else {
// This is a constant. Create a reverse transpose op for it.
- auto vectorType = VectorType::get(
- srcType.getShape(),
- cast<VectorType>(operand.getType()).getElementType());
+ auto vectorType =
+ srcType.clone(cast<VectorType>(operand.getType()).getElementType());
srcValues.push_back(rewriter.create<vector::TransposeOp>(
operand.getLoc(), vectorType, operand, invOrder));
}
}
- auto vectorType = VectorType::get(
- srcType.getShape(),
+ auto vectorType = srcType.clone(
cast<VectorType>(op->getResultTypes()[0]).getElementType());
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
@@ -1314,8 +1312,10 @@ struct CanonicalizeContractMatmulToMMT final
Value trans =
rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
VectorType newType =
- VectorType::get(cast<VectorType>(trans.getType()).getShape(),
- cast<VectorType>(mat.getType()).getElementType());
+ cast<VectorType>(trans.getType())
+ .clone(cast<VectorType>(mat.getType()).getElementType());
+ // VectorType::get(cast<VectorType>(trans.getType()).getShape(),
+ // cast<VectorType>(mat.getType()).getElementType());
return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
}
if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
index 2024da2a585d99f..d771d32d548bbe2 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir
@@ -1,8 +1,14 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule \
-// RUN: -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
-// RUN: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm | \
-// RUN: %mcr_aarch64_cmd -e=matmul_f32 -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
-// RUN: FileCheck %s
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE: -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = matmul_f32
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
func.func @matmul_f32() {
// Matrix dimensions
@@ -40,29 +46,37 @@ func.func @matmul_f32() {
return
}
-transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
- // Step 1: Tile
- %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op
- %func_op = get_parent_op %matmul : (!transform.any_op) -> !transform.op<"func.func">
- %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-
- // Step 2: Vectorize
- %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
-
- // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
- transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.reduction_to_contract
- transform.apply_patterns.vector.transfer_permutation_patterns
- transform.apply_patterns.vector.lower_masked_transfers
- } : !transform.op<"func.func">
-
- // Step 4: Lower vector.contract to vector.fma
- transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
- transform.apply_patterns.vector.lower_outerproduct
- } : !transform.op<"func.func">
+module attributes {transform.with_named_sequence} {
+transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+ : (!transform.any_op) -> !transform.any_op
+
+ // Step 1: Tile
+ %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+
+ // Step 2: Vectorize
+ %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
+
+ // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
+ %func = transform.structured.match ops{["func.func"]} in %module
+ : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.reduction_to_contract
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ transform.apply_patterns.vector.lower_masked_transfers
+ } : !transform.op<"func.func">
+
+ // Step 4: Lower vector.contract to vector.fma
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ transform.apply_patterns.vector.lower_outerproduct
+ } : !transform.op<"func.func">
+
+ transform.yield
+ }
}
func.func private @printMemrefF32(%ptr : tensor<*xf32>)
``````````
</details>
https://github.com/llvm/llvm-project/pull/73773
More information about the Mlir-commits
mailing list