[Mlir-commits] [mlir] 834cdc8 - [mlir][ArmSME] Fix get_tile_id type in zero lowering
Cullen Rhodes
llvmlistbot at llvm.org
Wed Aug 30 00:16:48 PDT 2023
Author: Cullen Rhodes
Date: 2023-08-30T07:16:35Z
New Revision: 834cdc8b64d572d2b2aae28c916a6b27dca1eb65
URL: https://github.com/llvm/llvm-project/commit/834cdc8b64d572d2b2aae28c916a6b27dca1eb65
DIFF: https://github.com/llvm/llvm-project/commit/834cdc8b64d572d2b2aae28c916a6b27dca1eb65.diff
LOG: [mlir][ArmSME] Fix get_tile_id type in zero lowering
The arm_sme.get_tile_id op returns a scalar integer, but the arm_sme.zero
op lowering incorrectly uses the vector's element type as its result type,
which could be floating-point.
Reviewed By: awarzynski, benmxwl-arm
Differential Revision: https://reviews.llvm.org/D159080
Added:
Modified:
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 88beb80de934fe..685f8d57f76f52 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -92,11 +92,12 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = zero.getLoc();
+ unsigned tileElementWidth =
+ zero.getVectorType().getElementType().getIntOrFloatBitWidth();
+
// Get Tile ID for the `zero` intrinsic.
auto tileId = rewriter.create<arm_sme::GetTileID>(
- loc, zero.getVectorType().getElementType());
-
- auto tileElementWidth = tileId.getType().getIntOrFloatBitWidth();
+ loc, rewriter.getIntegerType(tileElementWidth));
// Get the base mask for tile based on the element size.
// The base mask is just the mask to zero the first tile (of a size).
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 0ff136dd47685f..26cd91bd3e8956 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -3,6 +3,9 @@
// RUN: -allow-unregistered-dialect \
// RUN: | FileCheck %s
+// This test verifies the tile mask operand of the zero intrinsic zeroes
+// the correct tiles. Both integer and floating-point datatypes are checked.
+
// -----
// CHECK-LABEL: zero_za_b
@@ -32,9 +35,9 @@ func.func @zero_za_h() {
%zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
"prevent.dce"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
// CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1H]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xi16>
- %zero_za1h = arm_sme.zero : vector<[8]x[8]xi16>
- "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xi16>) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xf16>
+ %zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
+ "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xf16>) -> ()
return
}
@@ -65,9 +68,9 @@ func.func @zero_za_s() {
%zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
"prevent.dce"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
// CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3S]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xi32>
- %zero_za3s = arm_sme.zero : vector<[4]x[4]xi32>
- "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xi32>) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xf32>
+ %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
+ "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
return
}
@@ -122,8 +125,8 @@ func.func @zero_za_d() {
%zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
"prevent.dce"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
// CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA7D]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xi64>
- %zero_za7d = arm_sme.zero : vector<[2]x[2]xi64>
- "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xi64>) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xf64>
+ %zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
+ "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
return
}
More information about the Mlir-commits
mailing list