[Mlir-commits] [mlir] 834cdc8 - [mlir][ArmSME] Fix get_tile_id type in zero lowering

Cullen Rhodes llvmlistbot at llvm.org
Wed Aug 30 00:16:48 PDT 2023


Author: Cullen Rhodes
Date: 2023-08-30T07:16:35Z
New Revision: 834cdc8b64d572d2b2aae28c916a6b27dca1eb65

URL: https://github.com/llvm/llvm-project/commit/834cdc8b64d572d2b2aae28c916a6b27dca1eb65
DIFF: https://github.com/llvm/llvm-project/commit/834cdc8b64d572d2b2aae28c916a6b27dca1eb65.diff

LOG: [mlir][ArmSME] Fix get_tile_id type in zero lowering

The arm_sme.get_tile_id op returns a scalar integer, but the arm_sme.zero
op lowering incorrectly created it with the vector's element type as the
result type, which could be floating-point (e.g. f16/f32/f64).

Reviewed By: awarzynski, benmxwl-arm

Differential Revision: https://reviews.llvm.org/D159080

Added: 
    

Modified: 
    mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/ArmSME/tile-zero-masks.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 88beb80de934fe..685f8d57f76f52 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -92,11 +92,12 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = zero.getLoc();
 
+    unsigned tileElementWidth =
+        zero.getVectorType().getElementType().getIntOrFloatBitWidth();
+
     // Get Tile ID for the `zero` intrinsic.
     auto tileId = rewriter.create<arm_sme::GetTileID>(
-        loc, zero.getVectorType().getElementType());
-
-    auto tileElementWidth = tileId.getType().getIntOrFloatBitWidth();
+        loc, rewriter.getIntegerType(tileElementWidth));
 
     // Get the base mask for tile based on the element size.
     // The base mask is just the mask to zero the first tile (of a size).

diff  --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 0ff136dd47685f..26cd91bd3e8956 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -3,6 +3,9 @@
 // RUN:           -allow-unregistered-dialect                \
 // RUN: | FileCheck %s
 
+// This test verifies the tile mask operand of the zero intrinsic zeroes
+// the correct tiles. Both integer and floating-point datatypes are checked.
+
 // -----
 
 // CHECK-LABEL: zero_za_b
@@ -32,9 +35,9 @@ func.func @zero_za_h() {
   %zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
   "prevent.dce"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
   // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1H]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xi16>
-  %zero_za1h = arm_sme.zero : vector<[8]x[8]xi16>
-  "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xi16>) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xf16>
+  %zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
+  "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
@@ -65,9 +68,9 @@ func.func @zero_za_s() {
   %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
   "prevent.dce"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
   // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xi32>
-  %zero_za3s = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xi32>) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xf32>
+  %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
+  "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
@@ -122,8 +125,8 @@ func.func @zero_za_d() {
   %zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
   "prevent.dce"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
   // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA7D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xi64>
-  %zero_za7d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xi64>) -> ()
+  // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xf64>
+  %zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
+  "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
   return
 }


        


More information about the Mlir-commits mailing list