[PATCH] D159080: [mlir][ArmSME] Fix get_tile_id type in zero lowering
Cullen Rhodes via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 29 03:37:42 PDT 2023
c-rhodes created this revision.
c-rhodes added reviewers: awarzynski, benmxwl-arm.
c-rhodes added a project: MLIR.
Herald added subscribers: gysit, Dinistro, bviyer, Moerafaat, zero9178, bzcheeseman, sdasgup3, wenzhicui, wrengr, cota, teijeong, rdzhabarov, tatianashp, msifontes, jurahul, Kayjukh, grosul1, Joonsoo, liufengdb, aartbik, mgester, arpith-jacob, antiagainst, shauheen, rriddle, mehdi_amini, kristof.beyls.
Herald added a reviewer: ftynse.
Herald added a reviewer: dcaballe.
Herald added a project: All.
c-rhodes requested review of this revision.
Herald added a reviewer: nicolasvasilache.
Herald added subscribers: stephenneuendorffer, nicolasvasilache.
The arm_sme.get_tile_id op returns a scalar integer, but the arm_sme.zero
op lowering incorrectly passes the vector's element type as the result type
of get_tile_id, and that element type could be floating-point.
Repository:
rG LLVM Github Monorepo
https://reviews.llvm.org/D159080
Files:
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
Index: mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
===================================================================
--- mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -3,6 +3,9 @@
// RUN: -allow-unregistered-dialect \
// RUN: | FileCheck %s
+// This test verifies the 8-bit tile mask operand of the zero intrinsic zeroes
+// the correct tiles. Both integer and floating-point datatypes are checked.
+
// -----
// CHECK-LABEL: zero_za_b
@@ -32,9 +35,9 @@
%zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
"prevent.dce"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
// CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1H]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xi16>
- %zero_za1h = arm_sme.zero : vector<[8]x[8]xi16>
- "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xi16>) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xf16>
+ %zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
+ "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xf16>) -> ()
return
}
@@ -65,9 +68,9 @@
%zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
"prevent.dce"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
// CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3S]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xi32>
- %zero_za3s = arm_sme.zero : vector<[4]x[4]xi32>
- "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xi32>) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xf32>
+ %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
+ "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
return
}
@@ -122,8 +125,8 @@
%zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
"prevent.dce"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
// CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA7D]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xi64>
- %zero_za7d = arm_sme.zero : vector<[2]x[2]xi64>
- "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xi64>) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xf64>
+ %zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
+ "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
return
}
Index: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
===================================================================
--- mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -92,11 +92,12 @@
ConversionPatternRewriter &rewriter) const override {
auto loc = zero.getLoc();
+ unsigned tileElementWidth =
+ zero.getVectorType().getElementType().getIntOrFloatBitWidth();
+
// Get Tile ID for the `zero` intrinsic.
auto tileId = rewriter.create<arm_sme::GetTileID>(
- loc, zero.getVectorType().getElementType());
-
- auto tileElementWidth = tileId.getType().getIntOrFloatBitWidth();
+ loc, rewriter.getIntegerType(tileElementWidth));
// Get the base mask for tile based on the element size.
// The base mask is just the mask to zero the first tile (of a size).
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D159080.554246.patch
Type: text/x-patch
Size: 3435 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20230829/47463a67/attachment.bin>
More information about the llvm-commits
mailing list