[Mlir-commits] [mlir] [mlir][ArmSME] Fix scalable dims check in isValidSMETileVectorType (PR #65254)

Cullen Rhodes llvmlistbot at llvm.org
Mon Sep 4 04:29:12 PDT 2023


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/65254:

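The `allDimsScalable` check in `isValidSMETileVectorType` is incorrect and currently permits fixed-length vectors. The two conditions in the early exit were joined with `&&` instead of `||`, and the scalability test was not negated, so any rank-2 vector got past it.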
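The following minimal C++ sketch shows what the change does. It models only the rank/scalability early exit. `ToyVectorType`, `rejectedBefore`, `rejectedAfter` and `checkExamples` are hypothetical names used for illustration and are not MLIR API. The rest of `isValidSMETileVectorType` (the element-type and shape checks) is deliberately left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for mlir::VectorType. Only the rank and the
// per-dimension "scalable" flags are modelled; shape and element type are not.
struct ToyVectorType {
  std::vector<bool> scalableDims;  // one flag per dimension, true means [N]

  int64_t getRank() const { return static_cast<int64_t>(scalableDims.size()); }

  // True when every dimension is scalable (vacuously true for rank 0).
  bool allDimsScalable() const {
    return std::all_of(scalableDims.begin(), scalableDims.end(),
                       [](bool isScalable) { return isScalable; });
  }
};

// Early-exit condition before the patch: a vector is rejected only if it is
// NOT rank 2 AND fully scalable, so every rank-2 vector falls through.
bool rejectedBefore(const ToyVectorType &vType) {
  return (vType.getRank() != 2) && vType.allDimsScalable();
}

// Early-exit condition after the patch: a vector is rejected unless it is
// rank 2 with all dimensions scalable.
bool rejectedAfter(const ToyVectorType &vType) {
  return (vType.getRank() != 2) || !vType.allDimsScalable();
}

// Checks a few representative types. The comments name the MLIR type each
// toy value stands in for.
void checkExamples() {
  const ToyVectorType fixed2d{{false, false}};   // vector<16x16xi8>
  const ToyVectorType scalable2d{{true, true}};  // vector<[16]x[16]xi8>
  const ToyVectorType mixed2d{{true, false}};    // vector<[16]x16xi8>
  const ToyVectorType scalable1d{{true}};        // vector<[16]xi8>

  // The bug: the fixed-length 2-D vector slips past the old check.
  assert(!rejectedBefore(fixed2d));
  assert(rejectedAfter(fixed2d));

  // A fully scalable 2-D vector is accepted by both versions.
  assert(!rejectedBefore(scalable2d));
  assert(!rejectedAfter(scalable2d));

  // Partially scalable and non-2-D vectors are rejected by the fixed check.
  assert(!rejectedBefore(mixed2d));
  assert(rejectedAfter(mixed2d));
  assert(rejectedAfter(scalable1d));
}
```

The fixed condition is the exact negation of `rank == 2 && allDimsScalable`. As a result, only a rank-2 vector with both dimensions scalable, such as `vector<[16]x[16]xi8>`, reaches the remaining checks. The fixed-length `vector<16x16xi8>` used in the new test is now rejected at the early exit.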

From 47cf7b0d3d1ddd89ddad30b539c65fc92dec825f Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 4 Sep 2023 10:59:08 +0000
Subject: [PATCH] [mlir][ArmSME] Fix scalable dims check in
 isValidSMETileVectorType

Check for allDimsScalable is incorrect and currently permits fixed
vectors.
---
 mlir/lib/Dialect/ArmSME/Utils/Utils.cpp         |  2 +-
 mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir | 11 +++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index cc9e36cfca4193c..8b2be7bc1901b9a 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -31,7 +31,7 @@ bool mlir::arm_sme::isValidSMETileElementType(Type type) {
 }
 
 bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
-  if ((vType.getRank() != 2) && vType.allDimsScalable())
+  if ((vType.getRank() != 2) || !vType.allDimsScalable())
     return false;
 
   auto elemType = vType.getElementType();
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 8b6bd8f52d1900f..cb35de11ab5b3ed 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -154,6 +154,17 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
   return %0 : tensor<?x?xi8>
 }
 
+// -----
+
+// CHECK-LABEL: @transfer_write_2d__fixed
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi8>, memref<?x?xi8>
+  return
+}
+
 // =============================================================================
 // vector.broadcast
 // =============================================================================