[Mlir-commits] [mlir] a4d87e3 - [mlir][ArmSME] Calculate correct tile mask when lowering arm_sme.zero
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Aug 18 02:35:44 PDT 2023
Author: Benjamin Maxwell
Date: 2023-08-18T09:34:29Z
New Revision: a4d87e3d0655675318d2b4e420afaea65afa8f55
URL: https://github.com/llvm/llvm-project/commit/a4d87e3d0655675318d2b4e420afaea65afa8f55
DIFF: https://github.com/llvm/llvm-project/commit/a4d87e3d0655675318d2b4e420afaea65afa8f55.diff
LOG: [mlir][ArmSME] Calculate correct tile mask when lowering arm_sme.zero
This patch updates the lowering of 'arm_sme.zero' to intrinsics so
that it calculates the correct mask for the tile to zero.
The zero instruction takes an 8-bit mask that specifies which 64-bit
tiles to zero; ZA0.D to ZA7.D correspond to bits 0 to 7. Zeroing tiles
with element sizes of 8-bit to 32-bit just requires zeroing the right
64-bit tiles.
This is quite easy to calculate: each size has a "base mask" which can
be shifted left by the tile ID to get the mask for that tile:
base_mask << tile_id
After tile allocation, this will be folded to a constant mask.
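For illustration only (this is not part of the patch), here is a minimal
standalone C++ sketch of the computation; `baseMaskForWidth` is a
hypothetical helper that mirrors the switch added in the lowering below:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Base mask to zero the first tile of a given element width, per the
// ZERO instruction's encoding (ZA0.D to ZA7.D map to bits 0 to 7).
static uint32_t baseMaskForWidth(unsigned elementWidth) {
  switch (elementWidth) {
  case 8:  return 0b1111'1111; // ZA0.B covers ZA0.D..ZA7.D.
  case 16: return 0b0101'0101; // ZA0.H covers ZA0.D, ZA2.D, ZA4.D, ZA6.D.
  case 32: return 0b0001'0001; // ZA0.S covers ZA0.D and ZA4.D.
  case 64: return 0b0000'0001; // ZA0.D covers only itself.
  default: assert(false && "bad element size"); return 0;
  }
}

int main() {
  // E.g. the mask for ZA1.S (32-bit elements, tile ID 1):
  uint32_t mask = baseMaskForWidth(32) << 1; // 0b0010'0010
  std::printf("%u\n", mask);                 // prints 34
}
```

After tile allocation the shift amount is a constant, so the whole
expression folds to the constant masks checked in the new test file.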
Reviewed By: awarzynski
Differential Revision: https://reviews.llvm.org/D157902
Added:
mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
Modified:
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index e1df09ff9e0758..2f4dee7ba916e8 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -20,8 +20,6 @@
using namespace mlir;
using namespace mlir::arm_sme;
-static constexpr unsigned kZeroZAMask = 255;
-
namespace {
/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
/// ops to enable the ZA storage array.
@@ -51,21 +49,41 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
}
};
-/// Lower 'arm_sme.zero'. Use 'arm_sme.cast_tile_to_vector' to model the return
-/// value. The latter is a nop, which should be folded away (e.g. during
-/// canonicalisation).
+/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
+/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
+/// integer, to an i32 that can be passed as the `tile` parameter to the SME
+/// intrinsics. Or returns `tile` if already i32.
+Value castTileIDToI32(Value tile, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
+ tile.getDefiningOp())) &&
+ "expected ArmSME GetTileID or CastVectorToTile op!");
+ unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
+ if (tileElementWidth < 32)
+ return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
+ if (tileElementWidth > 32)
+ return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
+ return tile;
+}
+
+/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
/// ```mlir
-/// %0 = arm_sme.zero : vector<[16]x[16]xi8>
+/// %v = arm_sme.zero : vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
-/// %1 = arm_sme.get_tile_id : i8
-/// %2 = arm_sme.cast_tile_to_vector %1 : i8 to vector<[16]x[16]xi8>
-/// "arm_sme.intr.zero"(%c255_i32) : (i32) -> ()
+/// %tile_id = arm_sme.get_tile_id : i32
+/// %zero_mask = arith.shli %c17_i32, %tile_id : i32
+/// "arm_sme.intr.zero"(%zero_mask) : (i32) -> ()
+/// %v = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
/// ```
+///
+/// The 'arm_sme.cast_tile_to_vector' (which models the return) and the
+/// 'arith.shli' (which generates the mask) will be folded away after tile
+/// allocation and canonicalisation.
struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
using ConvertOpToLLVMPattern<ZeroOp>::ConvertOpToLLVMPattern;
@@ -75,18 +93,69 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
auto loc = zero.getLoc();
// Get Tile ID for the `zero` intrinsic.
- // TODO: Map this to a valid `mask` for the `zero` intrinsic.
auto tileId = rewriter.create<arm_sme::GetTileID>(
loc, zero.getVectorType().getElementType());
- // Create 'arm_sme.intr.zero' intrinsic to zero ZA.
- // FIXME: Replace the hard-coded mask with a valid value based
- // on `tileId`.
- auto mask = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask));
- rewriter.create<arm_sme::aarch64_sme_zero>(loc, mask);
-
- // Create `CastTileToVectorOp` to use it as the output
+ auto tileElementWidth = tileId.getType().getIntOrFloatBitWidth();
+
+  // Get the base mask for the tile based on the element size.
+  // The base mask is just the mask to zero the first tile of that size.
+ // These masks are derived from:
+ // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
+ auto baseMaskForSize = [&] {
+ switch (tileElementWidth) {
+ case 8:
+ // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
+ // 64-bit element tiles named ZA0.D to ZA7.D.
+ return 0b1111'1111;
+ case 16:
+ // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit element
+ // tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D.
+ // Shift this left once for ZA1.H.
+ return 0b0101'0101;
+ case 32:
+ // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
+ // element tiles named ZA0.D and ZA4.D.
+ // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
+ return 0b0001'0001;
+ case 64:
+      // Zeroing one of the 64-bit tiles ZA0.D to ZA7.D just requires
+ // setting the bit for that tile.
+ return 0b0000'0001;
+ default:
+ llvm_unreachable("bad element size");
+ }
+ }();
+ auto maskType = rewriter.getI32Type();
+ auto baseMask = rewriter.create<arith::ConstantOp>(
+ loc, maskType, rewriter.getIntegerAttr(maskType, baseMaskForSize));
+
+ // The actual mask is just the base mask shifted by the tile ID.
+ // This will be folded to a constant after tile allocation.
+ //
+  // The shift follows from the layout of the tiles and the fact that the
+  // tile ID is the index of the tile. For example, looking at the 32-bit
+  // ZAx.S tiles:
+ //
+ // ZA0.S = ZA0.D and ZA4.D
+ // * Tile ID -> 0
+ // * Mask -> 00010001 = (00010001 << 0)
+ // ZA1.S = ZA1.D and ZA5.D
+ // * Tile ID -> 1
+ // * Mask -> 00100010 = (00010001 << 1)
+ // ZA2.S = ZA2.D and ZA6.D
+ // * Tile ID -> 2
+ // * Mask -> 01000100 = (00010001 << 2)
+ // ZA3.S = ZA3.D and ZA7.D
+ // * Tile ID -> 3
+ // * Mask -> 10001000 = (00010001 << 3)
+ //
+ // This holds for all tile sizes.
+ auto tileMask = rewriter.create<arith::ShLIOp>(
+ loc, baseMask, castTileIDToI32(tileId, loc, rewriter));
+ rewriter.create<arm_sme::aarch64_sme_zero>(loc, tileMask);
+
+ // Create `CastTileToVectorOp` to use as the output.
rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(zero, zero.getType(),
tileId);
@@ -94,23 +163,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
}
};
-/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
-/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
-/// integer, to an i32 that can be passed as the `tile` parameter to the SME
-/// intrinsics. Or returns `tile` if already i32.
-Value castTileIDToI32(Value tile, Location loc,
- ConversionPatternRewriter &rewriter) {
- assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
- tile.getDefiningOp())) &&
- "expected ArmSME GetTileID or CastVectorToTile op!");
- unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
- if (tileElementWidth < 32)
- return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
- if (tileElementWidth > 32)
- return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
- return tile;
-}
-
/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
struct LoadTileSliceToArmSMELowering
: public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
new file mode 100644
index 00000000000000..0ff136dd47685f
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" \
+// RUN: -allocate-arm-sme-tiles -canonicalize \
+// RUN: -allow-unregistered-dialect \
+// RUN: | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: zero_za_b
+func.func @zero_za_b() {
+ // CHECK-DAG: %[[TILE_ID:.*]] = arith.constant 0 : i8
+ // CHECK-DAG: %[[ZERO_MASK:.*]] = arith.constant 255 : i32
+
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA0B:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+ %zero_za0b = arm_sme.zero : vector<[16]x[16]xi8>
+ "prevent.dce"(%zero_za0b) : (vector<[16]x[16]xi8>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: zero_za_h
+func.func @zero_za_h() {
+ // CHECK-DAG: %[[TILE_ID_ZA0H:.*]] = arith.constant 0 : i16
+ // CHECK-DAG: %[[TILE_ID_ZA1H:.*]] = arith.constant 1 : i16
+
+ // CHECK-DAG: %[[ZERO_MASK_ZA0H:.*]] = arith.constant 85 : i32
+ // CHECK-DAG: %[[ZERO_MASK_ZA1H:.*]] = arith.constant 170 : i32
+
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0H]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA0H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0H]] : i16 to vector<[8]x[8]xi16>
+ %zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
+ "prevent.dce"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1H]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xi16>
+ %zero_za1h = arm_sme.zero : vector<[8]x[8]xi16>
+ "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xi16>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: zero_za_s
+func.func @zero_za_s() {
+ // CHECK-DAG: %[[TILE_ID_ZA0S:.*]] = arith.constant 0 : i32
+ // CHECK-DAG: %[[TILE_ID_ZA1S:.*]] = arith.constant 1 : i32
+ // CHECK-DAG: %[[TILE_ID_ZA2S:.*]] = arith.constant 2 : i32
+ // CHECK-DAG: %[[TILE_ID_ZA3S:.*]] = arith.constant 3 : i32
+
+ // CHECK-DAG: %[[ZERO_MASK_ZA0S:.*]] = arith.constant 17 : i32
+ // CHECK-DAG: %[[ZERO_MASK_ZA1S:.*]] = arith.constant 34 : i32
+ // CHECK-DAG: %[[ZERO_MASK_ZA2S:.*]] = arith.constant 68 : i32
+ // CHECK-DAG: %[[ZERO_MASK_ZA3S:.*]] = arith.constant 136 : i32
+
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0S]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA0S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0S]] : i32 to vector<[4]x[4]xi32>
+ %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
+ "prevent.dce"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1S]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA1S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1S]] : i32 to vector<[4]x[4]xi32>
+ %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
+ "prevent.dce"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2S]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA2S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2S]] : i32 to vector<[4]x[4]xi32>
+ %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
+ "prevent.dce"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3S]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xi32>
+ %zero_za3s = arm_sme.zero : vector<[4]x[4]xi32>
+ "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xi32>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: zero_za_d
+func.func @zero_za_d() {
+ // CHECK-DAG: %[[TILE_ID_ZA0D:.*]] = arith.constant 0 : i64
+ // CHECK-DAG: %[[TILE_ID_ZA1D:.*]] = arith.constant 1 : i64
+ // CHECK-DAG: %[[TILE_ID_ZA2D:.*]] = arith.constant 2 : i64
+ // CHECK-DAG: %[[TILE_ID_ZA3D:.*]] = arith.constant 3 : i64
+ // CHECK-DAG: %[[TILE_ID_ZA4D:.*]] = arith.constant 4 : i64
+ // CHECK-DAG: %[[TILE_ID_ZA5D:.*]] = arith.constant 5 : i64
+ // CHECK-DAG: %[[TILE_ID_ZA6D:.*]] = arith.constant 6 : i64
+ // CHECK-DAG: %[[TILE_ID_ZA7D:.*]] = arith.constant 7 : i64
+
+ // CHECK-DAG: %[[ZERO_MASK_ZA0D:.*]] = arith.constant 1 : i32
+ // CHECK-DAG: %[[ZERO_MASK_ZA1D:.*]] = arith.constant 2 : i32
+ // CHECK-DAG: %[[ZERO_MASK_ZA2D:.*]] = arith.constant 4 : i32
+ // CHECK-DAG: %[[ZERO_MASK_ZA3D:.*]] = arith.constant 8 : i32
+ // CHECK-DAG: %[[ZERO_MASK_ZA4D:.*]] = arith.constant 16 : i32
+ // CHECK-DAG: %[[ZERO_MASK_ZA5D:.*]] = arith.constant 32 : i32
+ // CHECK-DAG: %[[ZERO_MASK_ZA6D:.*]] = arith.constant 64 : i32
+ // CHECK-DAG: %[[ZERO_MASK_ZA7D:.*]] = arith.constant 128 : i32
+
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0D]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA0D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0D]] : i64 to vector<[2]x[2]xi64>
+ %zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
+ "prevent.dce"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1D]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA1D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1D]] : i64 to vector<[2]x[2]xi64>
+ %zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
+ "prevent.dce"(%zero_za1d) : (vector<[2]x[2]xi64>) -> ()
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2D]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA2D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2D]] : i64 to vector<[2]x[2]xi64>
+ %zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
+ "prevent.dce"(%zero_za2d) : (vector<[2]x[2]xi64>) -> ()
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3D]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA3D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3D]] : i64 to vector<[2]x[2]xi64>
+ %zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
+ "prevent.dce"(%zero_za3d) : (vector<[2]x[2]xi64>) -> ()
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA4D]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA4D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA4D]] : i64 to vector<[2]x[2]xi64>
+ %zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
+ "prevent.dce"(%zero_za4d) : (vector<[2]x[2]xi64>) -> ()
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA5D]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA5D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA5D]] : i64 to vector<[2]x[2]xi64>
+ %zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
+ "prevent.dce"(%zero_za5d) : (vector<[2]x[2]xi64>) -> ()
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA6D]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA6D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA6D]] : i64 to vector<[2]x[2]xi64>
+ %zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
+ "prevent.dce"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
+ // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA7D]]) : (i32) -> ()
+ // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xi64>
+ %zero_za7d = arm_sme.zero : vector<[2]x[2]xi64>
+ "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xi64>) -> ()
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index ddd55319d347f3..de8bc5f93b2c76 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -9,8 +9,10 @@
// CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32
// CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
// CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-DAG: %[[EXT_TILE_ID:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
+// CHECK-DAG: %[[TILE_MASK:.*]] = arith.shli %[[C255]], %[[EXT_TILE_ID]] : i32
+// CHECK-DAG: "arm_sme.intr.zero"(%[[TILE_MASK]]) : (i32) -> ()
// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index