[Mlir-commits] [mlir] [mlir][ArmSME] Add arm_sme.move_tile_slice_to_vector op (PR #67652)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Sep 28 08:22:55 PDT 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/67652
>From d09d976e08280ae6d6e3cfac0fba74f3730df946 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 28 Sep 2023 09:17:35 +0000
Subject: [PATCH 1/4] [mlir][ArmSME] Add arm_sme.move_tile_slice_to_vector op
This adds a simple higher-level op wrapping the tile-slice-to-vector
intrinsics, and updates the existing vector.print lowering to use it.
Later patches will reuse this op to implement the vector.insert and
vector.extract lowerings.
---
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 33 +++++++
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 32 +------
.../Transforms/LegalizeForLLVMExport.cpp | 51 ++++++++++-
.../ArmSMEToSCF/arm-sme-to-scf.mlir | 16 ++--
mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 86 +++++++++++++++++++
mlir/test/Dialect/ArmSME/invalid.mlir | 8 ++
mlir/test/Dialect/ArmSME/roundtrip.mlir | 77 +++++++++++++++++
7 files changed, 261 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index d10cee5956d5e5f..10b2ae32b06dd37 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -501,6 +501,39 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
}];
}
+def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
+ TypesMatchWith<
+ "type of 'result' matches type of 'tile' slice",
+ "tile", "result",
+ "VectorType(VectorType::Builder(::llvm::cast<mlir::VectorType>($_self)).dropDim(0))">,
+]> {
+ let summary = "Move slice of a 2-D tile to a 1-D scalable vector";
+ let description = [{
+ The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
+ scalable tile at the give index. A tile slice is a 1-D vector of
+ horizontally or vertically contiguous elements within a ZA tile. Horizontal
+ tile slices are currently assumed when lowering to intrinsics.
+
+ Example 1: Extract `vector<[16]xi8>` from tile at the given index.
+ ```mlir
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
+ ```
+
+ Example 2: Extract `vector<[2]xf64>` from tile at the given index.
+ ```mlir
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
+ ```
+ }];
+
+ let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
+ let results = (outs SVEVector:$result);
+
+ let assemblyFormat = [{
+ $tile `[` $tile_slice_index `]` attr-dict
+ `:` type($result) `from` type($tile)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ArmSME Intrinsic op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index d084373439ab6bf..ec625a6806e2bee 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -202,16 +202,11 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
/// %c4 = arith.constant 4 : index
-/// %ptrue = arith.constant dense<true> : vector<[4]xi1>
-/// %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xf32> to i32
/// %vscale = vector.vscale
/// %svl_s = arith.muli %c4, %vscale : index
-/// %cst = arith.constant dense<0.000000e+00> : vector<[4]xf32>
/// scf.for %i = %c0 to %svl_s step %c1 {
-/// %slice_idx = arith.index_cast %i : index to i32
-/// %tile_slice = "arm_sme.intr.read.horiz"
-/// (%cst, %ptrue, %tile_id, %slice_idx)
-/// : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+/// %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
+/// : vector<[4]xf32> from vector<[4]x[4]xf32>
/// vector.print %tile_slice : vector<[4]xf32>
/// }
/// ```
@@ -229,23 +224,6 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
auto loc = printOp.getLoc();
- // Create an 'all true' predicate for each tile row.
- auto predicateType =
- VectorType::get(vectorType.getDimSize(1), rewriter.getI1Type(), true);
- auto allTruePredicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
-
- // Cast tile to i32 tile ID.
- auto tileId =
- rewriter.create<arm_sme::CastVectorToTile>(loc, printOp.getSource());
- Value tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
-
- // Zero destination/fallback for tile slice extraction.
- auto rowType = VectorType::get(vectorType.getDimSize(1),
- vectorType.getElementType(), true);
- auto zeroVector = rewriter.create<arith::ConstantOp>(
- loc, rowType, rewriter.getZeroAttr(rowType));
-
// Create a loop over the rows of the tile.
auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
auto minTileRows =
@@ -259,10 +237,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
rewriter.setInsertionPointToStart(forOp.getBody());
// Extract the current row from the tile.
Value rowIndex = forOp.getInductionVar();
- auto rowIndexI32 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), rowIndex);
- auto tileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
- loc, rowType, zeroVector, allTruePredicate, tileIdI32, rowIndexI32);
+ auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+ loc, printOp.getSource(), rowIndex);
// Print the row with a 1D vector.print.
rewriter.create<vector::PrintOp>(loc, tileSlice,
printOp.getPunctuation());
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 945716eea57543d..7157a915ee8a02a 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -402,6 +402,48 @@ struct MoveVectorToTileSliceToArmSMELowering
}
};
+/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal
+/// tile slices are currently supported.
+struct MoveTileSliceToVectorArmSMELowering
+ : public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
+ using ConvertOpToLLVMPattern<
+ arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector,
+ OpAdaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = moveTileSliceToVector.getLoc();
+ auto sliceType = moveTileSliceToVector.getResult().getType();
+ auto tile = moveTileSliceToVector.getTile();
+ auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
+
+ // Cast tile to i32 tile ID.
+ auto tileId = rewriter.create<arm_sme::CastVectorToTile>(loc, tile);
+ Value tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
+
+ // Create an 'all true' predicate for the tile slice.
+ auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
+ auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(predicateType, true));
+
+ // Zero destination/fallback for tile slice extraction.
+ auto zeroVector = rewriter.create<arith::ConstantOp>(
+ loc, sliceType, rewriter.getZeroAttr(sliceType));
+
+ // Cast tile slice from index to i32 for intrinsic.
+ auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(), sliceIndex);
+
+ // Create 'arm_sme.intr.read.horiz' to extract the tile slice.
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
+ moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+ tileIdI32, sliceIndexI32);
+
+ return success();
+ }
+};
+
/// Lower `vector.outerproduct` to SME MOPA intrinsics.
///
/// Example:
@@ -525,9 +567,9 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
- arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_write_horiz,
- arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
- arm_sme::aarch64_sme_za_disable>();
+ arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
+ arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa,
+ arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
target.addIllegalOp<vector::OuterProductOp>();
@@ -561,6 +603,7 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
patterns
.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
- LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering,
+ LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
+ MoveVectorToTileSliceToArmSMELowering,
VectorOuterProductToArmSMELowering>(converter);
}
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index b287a00171d9bd5..09f148bcd42f593 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -66,15 +66,11 @@ func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
}
// CHECK-LABEL: func.func @arm_sme_tile_print(
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>) {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
-// CHECK-DAG: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
-// CHECK-DAG: %[[ZERO_VECTOR:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
-// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK-NEXT: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_cast %[[TILE_SLICE_INDEX]] : index to i32
-// CHECK-NEXT: %[[TILE_SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VECTOR]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+// CHECK-NEXT: %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index a6b5217181481de..4c16e5c488a74cd 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -399,3 +399,89 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}
+
+
+//===----------------------------------------------------------------------===//
+// arm_sme.move_tile_slice_to_vector
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i8
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]xi8> {
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
+ return %slice : vector<[16]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i16
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index) -> vector<[8]xi16> {
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xi16> from vector<[8]x[8]xi16>
+ return %slice : vector<[8]xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i32
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> vector<[4]xi32> {
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xi32> from vector<[4]x[4]xi32>
+ return %slice : vector<[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i64
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index) -> vector<[2]xi64> {
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xi64> from vector<[2]x[2]xi64>
+ return %slice : vector<[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i128
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[1]xi128> from vector<[1]x[1]xi128>
+ return %slice : vector<[1]xi128>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f16
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index) -> vector<[8]xf16> {
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xf16> from vector<[8]x[8]xf16>
+ return %slice : vector<[8]xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_bf16
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> vector<[8]xbf16> {
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
+ return %slice : vector<[8]xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f32
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf32> {
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ return %slice : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f64
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
+ return %slice : vector<[2]xf64>
+}
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index f529b47cb98c3d9..431009b1b9ede2f 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -89,3 +89,11 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect
%0 = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xf32> into vector<[4]x[4]xf32>
return %0 : vector<[4]x[4]xf32>
}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
+ // expected-error @+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
+ %0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
+ return %0 : vector<[2]xf64>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 3232bad7996b486..f6d19359b8e3af8 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1058,3 +1058,80 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til
arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
return
}
+
+
+//===----------------------------------------------------------------------===//
+// arm_sme.move_tile_slice_to_vector
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]xi8> {
+ // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
+ return %slice : vector<[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index) -> vector<[8]xi16> {
+ // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xi16> from vector<[8]x[8]xi16>
+ return %slice : vector<[8]xi16>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> vector<[4]xi32> {
+ // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[4]xi32> from vector<[4]x[4]xi32>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xi32> from vector<[4]x[4]xi32>
+ return %slice : vector<[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index) -> vector<[2]xi64> {
+ // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xi64> from vector<[2]x[2]xi64>
+ return %slice : vector<[2]xi64>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+ // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[1]xi128> from vector<[1]x[1]xi128>
+ return %slice : vector<[1]xi128>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index) -> vector<[8]xf16> {
+ // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xf16> from vector<[8]x[8]xf16>
+ return %slice : vector<[8]xf16>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> vector<[8]xbf16> {
+ // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
+ return %slice : vector<[8]xbf16>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf32> {
+ // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ return %slice : vector<[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+ // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64>
+ %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
+ return %slice : vector<[2]xf64>
+}
>From 097ffbfb6d999b6e24fff6d7053e3148357ab109 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 28 Sep 2023 12:16:41 +0000
Subject: [PATCH 2/4] Fixup: Add getSliceType()
---
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 4 ++++
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp | 2 +-
2 files changed, 5 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 10b2ae32b06dd37..4fd92dbbbf3391a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -528,6 +528,10 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
let results = (outs SVEVector:$result);
+ let extraClassDeclaration = [{
+ VectorType getSliceType() { return getResult().getType(); }
+ }];
+
let assemblyFormat = [{
$tile `[` $tile_slice_index `]` attr-dict
`:` type($result) `from` type($tile)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 7157a915ee8a02a..0322c2f3fcd14d4 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -414,7 +414,7 @@ struct MoveTileSliceToVectorArmSMELowering
OpAdaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = moveTileSliceToVector.getLoc();
- auto sliceType = moveTileSliceToVector.getResult().getType();
+ auto sliceType = moveTileSliceToVector.getSliceType();
auto tile = moveTileSliceToVector.getTile();
auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
>From 54480245f4260f0dcce3fd51692d1cda4d66e6ae Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <macdue at dueutil.tech>
Date: Thu, 28 Sep 2023 16:14:11 +0100
Subject: [PATCH 3/4] Update mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
---
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 4fd92dbbbf3391a..66a432ea1b171e0 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -510,7 +510,7 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
let summary = "Move slice of a 2-D tile to a 1-D scalable vector";
let description = [{
The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
- scalable tile at the give index. A tile slice is a 1-D vector of
+ scalable tile at the given index. A tile slice is a 1-D vector of
horizontally or vertically contiguous elements within a ZA tile. Horizontal
tile slices are currently assumed when lowering to intrinsics.
>From 94842d717648a4742c22dd98a0de2e7817e93abd Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 28 Sep 2023 15:18:33 +0000
Subject: [PATCH 4/4] Fixup comment
---
mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index ec625a6806e2bee..881cc8575fb4824 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -191,7 +191,8 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
};
/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
-/// extracting them via a MOVA, then printing with a 1D `vector.print`.
+/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
+/// a 1D `vector.print`.
///
/// BEFORE:
/// ```mlir
More information about the Mlir-commits
mailing list