[Mlir-commits] [mlir] 2f055dd - [mlir][ArmSME] Add tile slice layout attr to vector <-> tile ops (#69186)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 25 06:44:35 PDT 2023
Author: Cullen Rhodes
Date: 2023-10-25T14:44:31+01:00
New Revision: 2f055ddca3fc365b850b5712cca4002185fd3933
URL: https://github.com/llvm/llvm-project/commit/2f055ddca3fc365b850b5712cca4002185fd3933
DIFF: https://github.com/llvm/llvm-project/commit/2f055ddca3fc365b850b5712cca4002185fd3933.diff
LOG: [mlir][ArmSME] Add tile slice layout attr to vector <-> tile ops (#69186)
This is used when lowering a masked tile_store with a non-zero pad (see
#69148).
This updates:
* `arm_sme.move_vector_to_tile_slice`
* `arm_sme.move_tile_slice_to_vector`
Added:
Modified:
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
mlir/test/Dialect/ArmSME/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index dab54b63d8d22be..9b9dbff10ea2da6 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -441,21 +441,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
of a 2-D scalable vector tile at the given index. The type of the 1-D
scalable vector to be moved must match the type of the tile slice. A tile
slice is a 1-D vector of horizontally or vertically contiguous elements
- within a ZA tile. Horizontal tile slices are currently assumed when
- lowering to intrinsics. The updated tile is returned as the result.
+ within a ZA tile. The updated tile is returned as the result.
- Example 1: Move a vector<[16]xi8> into tile at given index.
+ An optional tile slice layout attribute specifies whether the tile slice is
+ horizontal (default) or vertical.
+
+ Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index.
```mlir
%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
```
- Example 2: Move a vector<[2]xf64> into tile at given index.
+ Example 2: Move a vector<[2]xf64> into tile vertically at given index.
```mlir
- %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
+ %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
```
}];
let arguments = (ins
- SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index);
+ SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index,
+ ArmSME_TileSliceLayoutAttr:$layout);
let results = (outs SMETile:$result);
let extraClassDeclaration = [{
@@ -465,7 +468,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
}];
let assemblyFormat = [{
- $vector `,` $tile `,` $tile_slice_index
+ $vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
attr-dict `:` type($vector) `into` type($result)
}];
}
@@ -480,21 +483,26 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
let description = [{
The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
scalable tile at the given index. A tile slice is a 1-D vector of
- horizontally or vertically contiguous elements within a ZA tile. Horizontal
- tile slices are currently assumed when lowering to intrinsics.
+ horizontally or vertically contiguous elements within a ZA tile.
+
+ An optional tile slice layout attribute specifies whether the tile slice is
+ horizontal (default) or vertical.
- Example 1: Extract `vector<[16]xi8>` from tile at the given index.
+ Example 1: Extract `vector<[16]xi8>` from tile horizontally at the given index.
```mlir
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
```
- Example 2: Extract `vector<[2]xf64>` from tile at the given index.
+ Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index.
```mlir
- %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
```
}];
- let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
+ let arguments = (ins
+ SMETile:$tile, Index:$tile_slice_index,
+ ArmSME_TileSliceLayoutAttr:$layout
+ );
let results = (outs SVEVector:$result);
let extraClassDeclaration = [{
@@ -502,7 +510,7 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
}];
let assemblyFormat = [{
- $tile `[` $tile_slice_index `]` attr-dict
+ $tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict
`:` type($result) `from` type($tile)
}];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 5e13707ea0aa2b9..1231da356f8ed95 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -350,8 +350,7 @@ struct StoreTileSliceToArmSMELowering
}
};
-/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal
-/// tile slices are currently supported.
+/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
struct MoveVectorToTileSliceToArmSMELowering
: public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
using ConvertOpToLLVMPattern<
@@ -388,10 +387,19 @@ struct MoveVectorToTileSliceToArmSMELowering
auto tileI32 = castTileIDToI32(tile, loc, rewriter);
- // Create 'arm_sme.intr.write.horiz' to write vector to tile slice.
- rewriter.create<arm_sme::aarch64_sme_write_horiz>(
- loc, tileI32, tileSliceI32, allActiveMask,
- moveVectorToTileSliceOp.getVector());
+ // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
+ switch (moveVectorToTileSliceOp.getLayout()) {
+ case arm_sme::TileSliceLayout::Horizontal:
+ rewriter.create<arm_sme::aarch64_sme_write_horiz>(
+ loc, tileI32, tileSliceI32, allActiveMask,
+ moveVectorToTileSliceOp.getVector());
+ break;
+ case arm_sme::TileSliceLayout::Vertical:
+ rewriter.create<arm_sme::aarch64_sme_write_vert>(
+ loc, tileI32, tileSliceI32, allActiveMask,
+ moveVectorToTileSliceOp.getVector());
+ break;
+ }
// Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
// 'arm_sme.cast_tile_to_vector' to preserve dataflow.
@@ -402,8 +410,7 @@ struct MoveVectorToTileSliceToArmSMELowering
}
};
-/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal
-/// tile slices are currently supported.
+/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
struct MoveTileSliceToVectorArmSMELowering
: public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
using ConvertOpToLLVMPattern<
@@ -435,10 +442,19 @@ struct MoveTileSliceToVectorArmSMELowering
auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI32Type(), sliceIndex);
- // Create 'arm_sme.intr.read.horiz' to extract the tile slice.
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
- moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
- tileIdI32, sliceIndexI32);
+ // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
+ switch (moveTileSliceToVector.getLayout()) {
+ case arm_sme::TileSliceLayout::Horizontal:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
+ moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+ tileIdI32, sliceIndexI32);
+ break;
+ case arm_sme::TileSliceLayout::Vertical:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
+ moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+ tileIdI32, sliceIndexI32);
+ break;
+ }
return success();
}
@@ -680,7 +696,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
- arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa,
+ arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+ arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
target.addIllegalOp<vector::OuterProductOp>();
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 07485b3ee8ddf86..9074f0a7ee655c1 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -400,6 +400,29 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
return
}
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
+// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
+// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+ return
+}
//===----------------------------------------------------------------------===//
// arm_sme.move_tile_slice_to_vector
@@ -485,3 +508,12 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
return %slice : vector<[2]xf64>
}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
+// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
+ return %slice : vector<[1]xi128>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 427154158e797fd..e5ba81eff836027 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1059,6 +1059,14 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til
return
}
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_ver_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
+ return
+}
//===----------------------------------------------------------------------===//
// arm_sme.move_tile_slice_to_vector
@@ -1135,3 +1143,11 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
return %slice : vector<[2]xf64>
}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+ // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
+ return %slice : vector<[2]xf64>
+}
More information about the Mlir-commits mailing list