[Mlir-commits] [mlir] [mlir][ArmSME] Fix scalable dims check in isValidSMETileVectorType (PR #65254)
Cullen Rhodes
llvmlistbot at llvm.org
Mon Sep 4 04:29:12 PDT 2023
https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/65254:
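The check for allDimsScalable is incorrect and currently permits fixed-length vectors: the condition `(rank != 2) && allDimsScalable()` only rejects a type that is both not rank 2 and fully scalable, so a fixed vector such as `vector<16x16xi8>` is accepted. It should reject any type that is either not rank 2 or not fully scalable.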
>From 47cf7b0d3d1ddd89ddad30b539c65fc92dec825f Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 4 Sep 2023 10:59:08 +0000
Subject: [PATCH] [mlir][ArmSME] Fix scalable dims check in
isValidSMETileVectorType
The check for allDimsScalable is incorrect and currently permits fixed-length
vectors; it must reject types that are not rank 2 or not fully scalable.
---
mlir/lib/Dialect/ArmSME/Utils/Utils.cpp | 2 +-
mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir | 11 +++++++++++
2 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index cc9e36cfca4193c..8b2be7bc1901b9a 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -31,7 +31,7 @@ bool mlir::arm_sme::isValidSMETileElementType(Type type) {
}
bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
- if ((vType.getRank() != 2) && vType.allDimsScalable())
+ if ((vType.getRank() != 2) || !vType.allDimsScalable())
return false;
auto elemType = vType.getElementType();
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 8b6bd8f52d1900f..cb35de11ab5b3ed 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -154,6 +154,17 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
return %0 : tensor<?x?xi8>
}
+// -----
+
+// CHECK-LABEL: @transfer_write_2d__fixed
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?x?xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi8>, memref<?x?xi8>
+ return
+}
+
// =============================================================================
// vector.broadcast
// =============================================================================
More information about the Mlir-commits
mailing list