[Mlir-commits] [mlir] 538f135 - [mlir][ArmSME] Fix scalable dims check in isValidSMETileVectorType (#65254)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 4 23:08:21 PDT 2023


Author: Cullen Rhodes
Date: 2023-09-05T07:08:16+01:00
New Revision: 538f13584dbd79c0b97e52b839bf70674bffa4a5

URL: https://github.com/llvm/llvm-project/commit/538f13584dbd79c0b97e52b839bf70674bffa4a5
DIFF: https://github.com/llvm/llvm-project/commit/538f13584dbd79c0b97e52b839bf70674bffa4a5.diff

LOG: [mlir][ArmSME] Fix scalable dims check in isValidSMETileVectorType (#65254)

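The check for allDimsScalable is incorrect and currently permits fixed-size
vectors: the old condition only rejected types that were both non-2-D and
fully scalable, so a fixed-size 2-D type such as vector<16x16xi8> was accepted.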

Added: 
    

Modified: 
    mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
    mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index cc9e36cfca4193c..8b2be7bc1901b9a 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -31,7 +31,7 @@ bool mlir::arm_sme::isValidSMETileElementType(Type type) {
 }
 
 bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
-  if ((vType.getRank() != 2) && vType.allDimsScalable())
+  if ((vType.getRank() != 2) || !vType.allDimsScalable())
     return false;
 
   auto elemType = vType.getElementType();

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 8b6bd8f52d1900f..cb35de11ab5b3ed 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -154,6 +154,17 @@ func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> te
   return %0 : tensor<?x?xi8>
 }
 
+// -----
+
+// CHECK-LABEL: @transfer_write_2d__fixed
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi8>, memref<?x?xi8>
+  return
+}
+
 // =============================================================================
 // vector.broadcast
 // =============================================================================


        


More information about the Mlir-commits mailing list