[Mlir-commits] [mlir] [mlir][ArmSME] Lower vector.extract/insert on SME tiles to MOVA intrinsics (PR #67786)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue Oct 3 10:29:39 PDT 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/67786
>From bab01e1075e8698dd0e8c1e52338f488733d258e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 28 Sep 2023 10:53:23 +0000
Subject: [PATCH 1/4] [mlir][ArmSME] Lower vector.extract/insert on SME tiles
to MOVA intrinsics
This patch adds support for lowering vector.insert/extract of tile
slices or elements to ArmSME MOVA intrinsics.
For example:
```
// Extract slice from tile:
%slice = vector.extract %tile[%y] : vector<[4]xi32> from vector<[4]x[4]xi32>
// Extract element from tile:
%el = vector.extract %tile[%y,%x] : i32 from vector<[4]x[4]xi32>
// Insert slice into tile:
%new_tile = vector.insert %slice, %tile[%y]
: vector<[4]xi32> into vector<[4]x[4]xi32>
// Insert element into tile:
%new_tile = vector.insert %el, %tile[%y,%x]
: i32 into vector<[4]x[4]xi32>
```
---
.../Transforms/LegalizeForLLVMExport.cpp | 110 +++++++++++++++++-
.../Dialect/ArmSME/vector-ops-to-llvm.mlir | 78 +++++++++++++
2 files changed, 187 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index e75e958e18a2cfd..3f7be30172878e7 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -550,6 +550,113 @@ struct VectorOuterProductToArmSMELowering
}
};
+/// Lower `vector.extract` using SME MOVA intrinsics.
+///
+/// Example:
+/// ```
+/// %el = vector.extract %tile[%y,%x]: i32 from vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%y]
+/// : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %el = vector.extract %slice[%x] : i32 from vector<[4]xi32>
+/// ```
+struct VectorExtractToArmSMELowering
+ : public ConvertOpToLLVMPattern<vector::ExtractOp> {
+ using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ if (!isValidSMETileVectorType(sourceType))
+ return failure();
+
+ auto loc = extractOp.getLoc();
+ auto position = extractOp.getMixedPosition();
+
+ Value sourceVector = extractOp.getVector();
+
+ if (position.empty()) {
+ rewriter.replaceOp(extractOp, sourceVector);
+ return success();
+ }
+
+ Value sliceIndex = vector::getAsValues(rewriter, loc, position[0])[0];
+ auto moveTileSliceToVector =
+ rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
+ sliceIndex);
+
+ if (position.size() == 1) {
+ // Single index case: Extracts a 1D slice.
+ rewriter.replaceOp(extractOp, moveTileSliceToVector);
+ return success();
+ }
+
+ // Two indices case: Extracts a single element.
+ assert(position.size() == 2);
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, moveTileSliceToVector, position[1]);
+
+ return success();
+ }
+};
+
+/// Lower `vector.insert` using SME MOVA intrinsics.
+///
+/// Example:
+/// ```
+/// %new_tile = vector.insert %el, %tile[%y,%x] : i32 into vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%y]
+/// : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %new_slice = vector.insert %el, %slice[%x] : i32 into vector<[4]xi32>
+/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %y
+/// : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// ```
+struct VectorInsertToArmSMELowering
+ : public ConvertOpToLLVMPattern<vector::InsertOp> {
+ using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = insertOp.getResult().getType();
+
+ if (!isValidSMETileVectorType(resultType))
+ return failure();
+
+ auto loc = insertOp.getLoc();
+ auto position = insertOp.getMixedPosition();
+
+ Value source = adaptor.getSource();
+
+ if (position.empty()) {
+ rewriter.replaceOp(insertOp, source);
+ return success();
+ }
+
+ Value tileSlice = source;
+ Value sliceIndex = vector::getAsValues(rewriter, loc, position[0])[0];
+ if (position.size() == 2) {
+ // Two indices case: Insert signle element into tile.
+ // We need to first extract the existing slice and update the element.
+ tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+ loc, adaptor.getDest(), sliceIndex);
+ tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
+ position[1]);
+ }
+
+ // Insert the slice into the destination tile.
+ rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
+ insertOp, tileSlice, adaptor.getDest(), sliceIndex);
+ return success();
+ }
+};
+
} // namespace
void mlir::configureArmSMELegalizeForExportTarget(
@@ -604,5 +711,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
patterns.add<
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
- VectorOuterProductToArmSMELowering, ZeroOpConversion>(converter);
+ VectorOuterProductToArmSMELowering, ZeroOpConversion,
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 687ef79385334cf..d455ac0dce22b1a 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -496,3 +496,81 @@ func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vec
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}
+
+//===----------------------------------------------------------------------===//
+// vector.insert
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @vector_insert_slice(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[SLICE:.*]]: vector<[4]xi32>,
+// CHECK-SAME: %[[INDEX:.*]]: index)
+func.func @vector_insert_slice(%tile: vector<[4]x[4]xi32>, %slice: vector<[4]xi32>, %y: index) -> vector<[4]x[4]xi32>{
+ // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+ // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+ // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_castui %[[INDEX]] : index to i32
+ // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[TILE_SLICE_INDEX]], %[[PTRUE]], %[[SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+ %new_tile = vector.insert %slice, %tile[%y] : vector<[4]xi32> into vector<[4]x[4]xi32>
+ return %new_tile : vector<[4]x[4]xi32>
+}
+
+// -----
+
+
+// CHECK-LABEL: @vector_insert_element(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[EL:.*]]: i32,
+// CHECK-SAME: %[[ROW:.*]]: index,
+// CHECK-SAME: %[[COL:.*]]: index)
+func.func @vector_insert_element(%tile: vector<[4]x[4]xi32>, %el: i32, %y: index, %x: index) -> vector<[4]x[4]xi32> {
+ // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+ // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+ // CHECK-NEXT: %[[COL_I32:.*]] = builtin.unrealized_conversion_cast %[[COL]] : index to i64
+ // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+ // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
+ // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ // CHECK-NEXT: %[[NEW_SLICE:.*]] = llvm.insertelement %[[EL]], %[[SLICE]]{{\[}}%[[COL_I32]] : i64] : vector<[4]xi32>
+ // CHECK-NEXT: %[[SLICE_INDEX:.*]] = arith.index_castui %[[ROW]] : index to i32
+ // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+ %new_tile = vector.insert %el, %tile[%y,%x] : i32 into vector<[4]x[4]xi32>
+ return %new_tile : vector<[4]x[4]xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// vector.extract
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @vector_extract_slice(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[INDEX:.*]]: index)
+func.func @vector_extract_slice(%tile: vector<[4]x[4]xi32>, %y: index) -> vector<[4]xi32> {
+ // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+ // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+ // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+ // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_cast %[[INDEX]] : index to i32
+ // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ %slice = vector.extract %tile[%y] : vector<[4]xi32> from vector<[4]x[4]xi32>
+ return %slice : vector<[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[ROW:.*]]: index,
+// CHECK-SAME: %[[COL:.*]]: index)
+func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %y: index, %x: index) -> i32 {
+ // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+ // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+ // CHECK-NEXT: %[[COL_I32:.*]] = builtin.unrealized_conversion_cast %[[COL]] : index to i64
+ // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+ // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
+ // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ // CHECK-NEXT: %[[EL:.*]] = llvm.extractelement %[[SLICE]]{{\[}}%[[COL_I32]] : i64] : vector<[4]xi32>
+ %el = vector.extract %tile[%y,%x] : i32 from vector<[4]x[4]xi32>
+ return %el : i32
+}
>From f7d0806a6c158e87f24e55c6e2418629a3e735c0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 2 Oct 2023 11:32:41 +0000
Subject: [PATCH 2/4] Add simplified tests for other element types
---
.../Dialect/ArmSME/vector-ops-to-llvm.mlir | 331 +++++++++++++++++-
1 file changed, 320 insertions(+), 11 deletions(-)
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index d455ac0dce22b1a..a987f9612fb69a8 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -503,28 +503,105 @@ func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vec
// -----
-// CHECK-LABEL: @vector_insert_slice(
+// CHECK-LABEL: @vector_insert_slice_i32(
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME: %[[SLICE:.*]]: vector<[4]xi32>,
// CHECK-SAME: %[[INDEX:.*]]: index)
-func.func @vector_insert_slice(%tile: vector<[4]x[4]xi32>, %slice: vector<[4]xi32>, %y: index) -> vector<[4]x[4]xi32>{
+func.func @vector_insert_slice_i32(%tile: vector<[4]x[4]xi32>, %slice: vector<[4]xi32>, %row: index) -> vector<[4]x[4]xi32>{
// CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
// CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_castui %[[INDEX]] : index to i32
// CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[TILE_SLICE_INDEX]], %[[PTRUE]], %[[SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
- %new_tile = vector.insert %slice, %tile[%y] : vector<[4]xi32> into vector<[4]x[4]xi32>
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xi32> into vector<[4]x[4]xi32>
return %new_tile : vector<[4]x[4]xi32>
}
// -----
+// CHECK-LABEL: @vector_insert_slice_i128
+func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[1]xi128> into vector<[1]x[1]xi128>
+ return %new_tile : vector<[1]x[1]xi128>
+}
+
+
+// -----
+
+// CHECK-LABEL: @vector_insert_slice_i64
+func.func @vector_insert_slice_i64(%tile: vector<[2]x[2]xi64>, %slice: vector<[2]xi64>, %row: index) -> vector<[2]x[2]xi64> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xi64> into vector<[2]x[2]xi64>
+ return %new_tile : vector<[2]x[2]xi64>
+}
+
+
+// -----
+
+// CHECK-LABEL: @vector_insert_slice_f64
+func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xf64> into vector<[2]x[2]xf64>
+ return %new_tile : vector<[2]x[2]xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_insert_slice_f32
+func.func @vector_insert_slice_f32(%tile: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xf32> into vector<[4]x[4]xf32>
+ return %new_tile : vector<[4]x[4]xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @vector_insert_slice_i16
+func.func @vector_insert_slice_i16(%tile: vector<[8]x[8]xi16>, %slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xi16> into vector<[8]x[8]xi16>
+ return %new_tile : vector<[8]x[8]xi16>
+}
+
+
+// -----
+
+// CHECK-LABEL: @vector_insert_slice_f16
+func.func @vector_insert_slice_f16(%tile: vector<[8]x[8]xf16>, %slice: vector<[8]xf16>, %row: index) -> vector<[8]x[8]xf16> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xf16> into vector<[8]x[8]xf16>
+ return %new_tile : vector<[8]x[8]xf16>
+}
+
-// CHECK-LABEL: @vector_insert_element(
+// -----
+
+// CHECK-LABEL: @vector_insert_slice_bf16
+func.func @vector_insert_slice_bf16(%tile: vector<[8]x[8]xbf16>, %slice: vector<[8]xbf16>, %row: index) -> vector<[8]x[8]xbf16> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+ return %new_tile : vector<[8]x[8]xbf16>
+}
+
+
+// -----
+
+// CHECK-LABEL: @vector_insert_slice_i8
+func.func @vector_insert_slice_i8(%tile: vector<[16]x[16]xi8>, %slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[16]xi8> into vector<[16]x[16]xi8>
+ return %new_tile : vector<[16]x[16]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_insert_element_i32(
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME: %[[EL:.*]]: i32,
// CHECK-SAME: %[[ROW:.*]]: index,
// CHECK-SAME: %[[COL:.*]]: index)
-func.func @vector_insert_element(%tile: vector<[4]x[4]xi32>, %el: i32, %y: index, %x: index) -> vector<[4]x[4]xi32> {
+func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> {
// CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
// CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-NEXT: %[[COL_I32:.*]] = builtin.unrealized_conversion_cast %[[COL]] : index to i64
@@ -534,36 +611,188 @@ func.func @vector_insert_element(%tile: vector<[4]x[4]xi32>, %el: i32, %y: index
// CHECK-NEXT: %[[NEW_SLICE:.*]] = llvm.insertelement %[[EL]], %[[SLICE]]{{\[}}%[[COL_I32]] : i64] : vector<[4]xi32>
// CHECK-NEXT: %[[SLICE_INDEX:.*]] = arith.index_castui %[[ROW]] : index to i32
// CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
- %new_tile = vector.insert %el, %tile[%y,%x] : i32 into vector<[4]x[4]xi32>
+ %new_tile = vector.insert %el, %tile[%row, %col] : i32 into vector<[4]x[4]xi32>
return %new_tile : vector<[4]x[4]xi32>
}
+// -----
+
+// CHECK-LABEL: @vector_insert_element_i128
+func.func @vector_insert_element_i128(%tile: vector<[1]x[1]xi128>, %el: i128, %row: index, %col: index) -> vector<[1]x[1]xi128> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : i128 into vector<[1]x[1]xi128>
+ return %new_tile : vector<[1]x[1]xi128>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_insert_element_i64
+func.func @vector_insert_element_i64(%tile: vector<[2]x[2]xi64>, %el: i64, %row: index, %col: index) -> vector<[2]x[2]xi64> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : i64 into vector<[2]x[2]xi64>
+ return %new_tile : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_insert_element_f64
+func.func @vector_insert_element_f64(%tile: vector<[2]x[2]xf64>, %el: f64, %row: index, %col: index) -> vector<[2]x[2]xf64> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : f64 into vector<[2]x[2]xf64>
+ return %new_tile : vector<[2]x[2]xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_insert_element_f32
+func.func @vector_insert_element_f32(%tile: vector<[4]x[4]xf32>, %el: f32, %row: index, %col: index) -> vector<[4]x[4]xf32> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : f32 into vector<[4]x[4]xf32>
+ return %new_tile : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_insert_element_i16
+func.func @vector_insert_element_i16(%tile: vector<[8]x[8]xi16>, %el: i16, %row: index, %col: index) -> vector<[8]x[8]xi16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : i16 into vector<[8]x[8]xi16>
+ return %new_tile : vector<[8]x[8]xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_insert_element_f16
+func.func @vector_insert_element_f16(%tile: vector<[8]x[8]xf16>, %el: f16, %row: index, %col: index) -> vector<[8]x[8]xf16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : f16 into vector<[8]x[8]xf16>
+ return %new_tile : vector<[8]x[8]xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_insert_element_bf16
+func.func @vector_insert_element_bf16(%tile: vector<[8]x[8]xbf16>, %el: bf16, %row: index, %col: index) -> vector<[8]x[8]xbf16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : bf16 into vector<[8]x[8]xbf16>
+ return %new_tile : vector<[8]x[8]xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_insert_element_i8
+func.func @vector_insert_element_i8(%tile: vector<[16]x[16]xi8>, %el: i8, %row: index, %col: index) -> vector<[16]x[16]xi8> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : i8 into vector<[16]x[16]xi8>
+ return %new_tile : vector<[16]x[16]xi8>
+}
+
//===----------------------------------------------------------------------===//
// vector.extract
//===----------------------------------------------------------------------===//
// -----
-// CHECK-LABEL: @vector_extract_slice(
+// CHECK-LABEL: @vector_extract_slice_i32(
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME: %[[INDEX:.*]]: index)
-func.func @vector_extract_slice(%tile: vector<[4]x[4]xi32>, %y: index) -> vector<[4]xi32> {
+func.func @vector_extract_slice_i32(%tile: vector<[4]x[4]xi32>, %row: index) -> vector<[4]xi32> {
// CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
// CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
// CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_cast %[[INDEX]] : index to i32
// CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
- %slice = vector.extract %tile[%y] : vector<[4]xi32> from vector<[4]x[4]xi32>
+ %slice = vector.extract %tile[%row] : vector<[4]xi32> from vector<[4]x[4]xi32>
return %slice : vector<[4]xi32>
}
// -----
+// CHECK-LABEL: @vector_extract_slice_i128
+func.func @vector_extract_slice_i128(%tile: vector<[1]x[1]xi128>, %row: index) -> vector<[1]xi128> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ %slice = vector.extract %tile[%row] : vector<[1]xi128> from vector<[1]x[1]xi128>
+ return %slice : vector<[1]xi128>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_slice_i64
+func.func @vector_extract_slice_i64(%tile: vector<[2]x[2]xi64>, %row: index) -> vector<[2]xi64> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+ %slice = vector.extract %tile[%row] : vector<[2]xi64> from vector<[2]x[2]xi64>
+ return %slice : vector<[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_slice_f64
+func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) -> vector<[2]xf64> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ %slice = vector.extract %tile[%row] : vector<[2]xf64> from vector<[2]x[2]xf64>
+ return %slice : vector<[2]xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_slice_f32
+func.func @vector_extract_slice_f32(%tile: vector<[4]x[4]xf32>, %row: index) -> vector<[4]xf32> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ %slice = vector.extract %tile[%row] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ return %slice : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_slice_i16
+func.func @vector_extract_slice_i16(%tile: vector<[8]x[8]xi16>, %row: index) -> vector<[8]xi16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ %slice = vector.extract %tile[%row] : vector<[8]xi16> from vector<[8]x[8]xi16>
+ return %slice : vector<[8]xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_slice_f16
+func.func @vector_extract_slice_f16(%tile: vector<[8]x[8]xf16>, %row: index) -> vector<[8]xf16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+ %slice = vector.extract %tile[%row] : vector<[8]xf16> from vector<[8]x[8]xf16>
+ return %slice : vector<[8]xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_slice_bf16
+func.func @vector_extract_slice_bf16(%tile: vector<[8]x[8]xbf16>, %row: index) -> vector<[8]xbf16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+ %slice = vector.extract %tile[%row] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
+ return %slice : vector<[8]xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_slice_i8
+func.func @vector_extract_slice_i8(%tile: vector<[16]x[16]xi8>, %row: index) -> vector<[16]xi8> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ %slice = vector.extract %tile[%row] : vector<[16]xi8> from vector<[16]x[16]xi8>
+ return %slice : vector<[16]xi8>
+}
+
+// -----
+
// CHECK-LABEL: @vector_extract_element(
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME: %[[ROW:.*]]: index,
// CHECK-SAME: %[[COL:.*]]: index)
-func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %y: index, %x: index) -> i32 {
+func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col: index) -> i32 {
// CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
// CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-NEXT: %[[COL_I32:.*]] = builtin.unrealized_conversion_cast %[[COL]] : index to i64
@@ -571,6 +800,86 @@ func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %y: index, %x: ind
// CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
// CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
// CHECK-NEXT: %[[EL:.*]] = llvm.extractelement %[[SLICE]]{{\[}}%[[COL_I32]] : i64] : vector<[4]xi32>
- %el = vector.extract %tile[%y,%x] : i32 from vector<[4]x[4]xi32>
+ %el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32>
return %el : i32
}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element_i128
+func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index, %col: index) -> i128 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[1]xi128>
+ %el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128>
+ return %el : i128
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element_i64
+func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %col: index) -> i64 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[2]xi64>
+ %el = vector.extract %tile[%row, %col] : i64 from vector<[2]x[2]xi64>
+ return %el : i64
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element_f64
+func.func @vector_extract_element_f64(%tile: vector<[2]x[2]xf64>, %row: index, %col: index) -> f64 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[2]xf64>
+ %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
+ return %el : f64
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element_f32
+func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %col: index) -> f32 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[4]xf32>
+ %el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32>
+ return %el : f32
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element_i16
+func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %col: index) -> i16 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xi16>
+ %el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16>
+ return %el : i16
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element_f16
+func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %col: index) -> f16 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xf16>
+ %el = vector.extract %tile[%row, %col] : f16 from vector<[8]x[8]xf16>
+ return %el : f16
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element_bf16
+func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index, %col: index) -> bf16 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xbf16>
+ %el = vector.extract %tile[%row, %col] : bf16 from vector<[8]x[8]xbf16>
+ return %el : bf16
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element_i8
+func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %col: index) -> i8 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[16]xi8>
+ %el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8>
+ return %el : i8
+}
>From 89cba2d484fff7f198d3c7970317db6bc77a3447 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 2 Oct 2023 11:40:10 +0000
Subject: [PATCH 3/4] Fixups
---
.../ArmSME/Transforms/LegalizeForLLVMExport.cpp | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 3f7be30172878e7..ff82c727bd8cb3e 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -550,7 +550,7 @@ struct VectorOuterProductToArmSMELowering
}
};
-/// Lower `vector.extract` using SME MOVA intrinsics.
+/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
///
/// Example:
/// ```
@@ -578,12 +578,13 @@ struct VectorExtractToArmSMELowering
Value sourceVector = extractOp.getVector();
+ // Extract entire vector. Should be handled by folder, but just to be safe.
if (position.empty()) {
rewriter.replaceOp(extractOp, sourceVector);
return success();
}
- Value sliceIndex = vector::getAsValues(rewriter, loc, position[0])[0];
+ Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
auto moveTileSliceToVector =
rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
sliceIndex);
@@ -603,7 +604,8 @@ struct VectorExtractToArmSMELowering
}
};
-/// Lower `vector.insert` using SME MOVA intrinsics.
+/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
+/// `arm_sme.move_tile_slice_to_vector`.
///
/// Example:
/// ```
@@ -634,15 +636,17 @@ struct VectorInsertToArmSMELowering
Value source = adaptor.getSource();
+ // Overwrite entire vector with value. Should be handled by folder, but
+ // just to be safe.
if (position.empty()) {
rewriter.replaceOp(insertOp, source);
return success();
}
Value tileSlice = source;
- Value sliceIndex = vector::getAsValues(rewriter, loc, position[0])[0];
+ Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
if (position.size() == 2) {
- // Two indices case: Insert signle element into tile.
+ // Two indices case: Insert single element into tile.
// We need to first extract the existing slice and update the element.
tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
loc, adaptor.getDest(), sliceIndex);
>From 4d9501a153c72d89e201ef31c1235e3d2acde404 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 3 Oct 2023 17:22:09 +0000
Subject: [PATCH 4/4] Fixups
---
.../Transforms/LegalizeForLLVMExport.cpp | 15 +-
.../Dialect/ArmSME/vector-ops-to-llvm.mlir | 284 +++++++++---------
2 files changed, 156 insertions(+), 143 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index ff82c727bd8cb3e..5e13707ea0aa2b9 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -554,13 +554,13 @@ struct VectorOuterProductToArmSMELowering
///
/// Example:
/// ```
-/// %el = vector.extract %tile[%y,%x]: i32 from vector<[4]x[4]xi32>
+/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
/// ```
/// Becomes:
/// ```
-/// %slice = arm_sme.move_tile_slice_to_vector %tile[%y]
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %el = vector.extract %slice[%x] : i32 from vector<[4]xi32>
+/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
/// ```
struct VectorExtractToArmSMELowering
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
@@ -609,14 +609,15 @@ struct VectorExtractToArmSMELowering
///
/// Example:
/// ```
-/// %new_tile = vector.insert %el, %tile[%y,%x] : i32 into vector<[4]x[4]xi32>
+/// %new_tile = vector.insert %el, %tile[%row, %col]
+/// : i32 into vector<[4]x[4]xi32>
/// ```
/// Becomes:
/// ```
-/// %slice = arm_sme.move_tile_slice_to_vector %tile[%y]
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %new_slice = vector.insert %el, %slice[%x] : i32 into vector<[4]xi32>
-/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %y
+/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
+/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
/// ```
struct VectorInsertToArmSMELowering
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index a987f9612fb69a8..319e7609a759d85 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -518,13 +518,21 @@ func.func @vector_insert_slice_i32(%tile: vector<[4]x[4]xi32>, %slice: vector<[4
// -----
-// CHECK-LABEL: @vector_insert_slice_i128
-func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
- %new_tile = vector.insert %slice, %tile[%row] : vector<[1]xi128> into vector<[1]x[1]xi128>
- return %new_tile : vector<[1]x[1]xi128>
+// CHECK-LABEL: @vector_insert_slice_i8
+func.func @vector_insert_slice_i8(%tile: vector<[16]x[16]xi8>, %slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[16]xi8> into vector<[16]x[16]xi8>
+ return %new_tile : vector<[16]x[16]xi8>
}
+// -----
+
+// CHECK-LABEL: @vector_insert_slice_i16
+func.func @vector_insert_slice_i16(%tile: vector<[8]x[8]xi16>, %slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xi16> into vector<[8]x[8]xi16>
+ return %new_tile : vector<[8]x[8]xi16>
+}
// -----
@@ -535,36 +543,15 @@ func.func @vector_insert_slice_i64(%tile: vector<[2]x[2]xi64>, %slice: vector<[2
return %new_tile : vector<[2]x[2]xi64>
}
-
-// -----
-
-// CHECK-LABEL: @vector_insert_slice_f64
-func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> {
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
- %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xf64> into vector<[2]x[2]xf64>
- return %new_tile : vector<[2]x[2]xf64>
-}
-
// -----
-// CHECK-LABEL: @vector_insert_slice_f32
-func.func @vector_insert_slice_f32(%tile: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> {
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
- %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xf32> into vector<[4]x[4]xf32>
- return %new_tile : vector<[4]x[4]xf32>
-}
-
-
-// -----
-
-// CHECK-LABEL: @vector_insert_slice_i16
-func.func @vector_insert_slice_i16(%tile: vector<[8]x[8]xi16>, %slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> {
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
- %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xi16> into vector<[8]x[8]xi16>
- return %new_tile : vector<[8]x[8]xi16>
+// CHECK-LABEL: @vector_insert_slice_i128
+func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[1]xi128> into vector<[1]x[1]xi128>
+ return %new_tile : vector<[1]x[1]xi128>
}
-
// -----
// CHECK-LABEL: @vector_insert_slice_f16
@@ -574,7 +561,6 @@ func.func @vector_insert_slice_f16(%tile: vector<[8]x[8]xf16>, %slice: vector<[8
return %new_tile : vector<[8]x[8]xf16>
}
-
// -----
// CHECK-LABEL: @vector_insert_slice_bf16
@@ -584,6 +570,23 @@ func.func @vector_insert_slice_bf16(%tile: vector<[8]x[8]xbf16>, %slice: vector<
return %new_tile : vector<[8]x[8]xbf16>
}
+// -----
+
+// CHECK-LABEL: @vector_insert_slice_f32
+func.func @vector_insert_slice_f32(%tile: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xf32> into vector<[4]x[4]xf32>
+ return %new_tile : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_insert_slice_f64
+func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+ %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xf64> into vector<[2]x[2]xf64>
+ return %new_tile : vector<[2]x[2]xf64>
+}
// -----
@@ -617,12 +620,22 @@ func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row:
// -----
-// CHECK-LABEL: @vector_insert_element_i128
-func.func @vector_insert_element_i128(%tile: vector<[1]x[1]xi128>, %el: i128, %row: index, %col: index) -> vector<[1]x[1]xi128> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
- %new_tile = vector.insert %el, %tile[%row, %col] : i128 into vector<[1]x[1]xi128>
- return %new_tile : vector<[1]x[1]xi128>
+// CHECK-LABEL: @vector_insert_element_i8
+func.func @vector_insert_element_i8(%tile: vector<[16]x[16]xi8>, %el: i8, %row: index, %col: index) -> vector<[16]x[16]xi8> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : i8 into vector<[16]x[16]xi8>
+ return %new_tile : vector<[16]x[16]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_insert_element_i16
+func.func @vector_insert_element_i16(%tile: vector<[8]x[8]xi16>, %el: i16, %row: index, %col: index) -> vector<[8]x[8]xi16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : i16 into vector<[8]x[8]xi16>
+ return %new_tile : vector<[8]x[8]xi16>
}
// -----
@@ -637,32 +650,12 @@ func.func @vector_insert_element_i64(%tile: vector<[2]x[2]xi64>, %el: i64, %row:
// -----
-// CHECK-LABEL: @vector_insert_element_f64
-func.func @vector_insert_element_f64(%tile: vector<[2]x[2]xf64>, %el: f64, %row: index, %col: index) -> vector<[2]x[2]xf64> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
- %new_tile = vector.insert %el, %tile[%row, %col] : f64 into vector<[2]x[2]xf64>
- return %new_tile : vector<[2]x[2]xf64>
-}
-
-// -----
-
-// CHECK-LABEL: @vector_insert_element_f32
-func.func @vector_insert_element_f32(%tile: vector<[4]x[4]xf32>, %el: f32, %row: index, %col: index) -> vector<[4]x[4]xf32> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
- %new_tile = vector.insert %el, %tile[%row, %col] : f32 into vector<[4]x[4]xf32>
- return %new_tile : vector<[4]x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @vector_insert_element_i16
-func.func @vector_insert_element_i16(%tile: vector<[8]x[8]xi16>, %el: i16, %row: index, %col: index) -> vector<[8]x[8]xi16> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
- %new_tile = vector.insert %el, %tile[%row, %col] : i16 into vector<[8]x[8]xi16>
- return %new_tile : vector<[8]x[8]xi16>
+// CHECK-LABEL: @vector_insert_element_i128
+func.func @vector_insert_element_i128(%tile: vector<[1]x[1]xi128>, %el: i128, %row: index, %col: index) -> vector<[1]x[1]xi128> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : i128 into vector<[1]x[1]xi128>
+ return %new_tile : vector<[1]x[1]xi128>
}
// -----
@@ -687,12 +680,22 @@ func.func @vector_insert_element_bf16(%tile: vector<[8]x[8]xbf16>, %el: bf16, %r
// -----
-// CHECK-LABEL: @vector_insert_element_i8
-func.func @vector_insert_element_i8(%tile: vector<[16]x[16]xi8>, %el: i8, %row: index, %col: index) -> vector<[16]x[16]xi8> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
- %new_tile = vector.insert %el, %tile[%row, %col] : i8 into vector<[16]x[16]xi8>
- return %new_tile : vector<[16]x[16]xi8>
+// CHECK-LABEL: @vector_insert_element_f32
+func.func @vector_insert_element_f32(%tile: vector<[4]x[4]xf32>, %el: f32, %row: index, %col: index) -> vector<[4]x[4]xf32> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : f32 into vector<[4]x[4]xf32>
+ return %new_tile : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_insert_element_f64
+func.func @vector_insert_element_f64(%tile: vector<[2]x[2]xf64>, %el: f64, %row: index, %col: index) -> vector<[2]x[2]xf64> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+ %new_tile = vector.insert %el, %tile[%row, %col] : f64 into vector<[2]x[2]xf64>
+ return %new_tile : vector<[2]x[2]xf64>
}
//===----------------------------------------------------------------------===//
@@ -716,11 +719,20 @@ func.func @vector_extract_slice_i32(%tile: vector<[4]x[4]xi32>, %row: index) ->
// -----
-// CHECK-LABEL: @vector_extract_slice_i128
-func.func @vector_extract_slice_i128(%tile: vector<[1]x[1]xi128>, %row: index) -> vector<[1]xi128> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
- %slice = vector.extract %tile[%row] : vector<[1]xi128> from vector<[1]x[1]xi128>
- return %slice : vector<[1]xi128>
+// CHECK-LABEL: @vector_extract_slice_i8
+func.func @vector_extract_slice_i8(%tile: vector<[16]x[16]xi8>, %row: index) -> vector<[16]xi8> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ %slice = vector.extract %tile[%row] : vector<[16]xi8> from vector<[16]x[16]xi8>
+ return %slice : vector<[16]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_slice_i16
+func.func @vector_extract_slice_i16(%tile: vector<[8]x[8]xi16>, %row: index) -> vector<[8]xi16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ %slice = vector.extract %tile[%row] : vector<[8]xi16> from vector<[8]x[8]xi16>
+ return %slice : vector<[8]xi16>
}
// -----
@@ -734,29 +746,11 @@ func.func @vector_extract_slice_i64(%tile: vector<[2]x[2]xi64>, %row: index) ->
// -----
-// CHECK-LABEL: @vector_extract_slice_f64
-func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) -> vector<[2]xf64> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
- %slice = vector.extract %tile[%row] : vector<[2]xf64> from vector<[2]x[2]xf64>
- return %slice : vector<[2]xf64>
-}
-
-// -----
-
-// CHECK-LABEL: @vector_extract_slice_f32
-func.func @vector_extract_slice_f32(%tile: vector<[4]x[4]xf32>, %row: index) -> vector<[4]xf32> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
- %slice = vector.extract %tile[%row] : vector<[4]xf32> from vector<[4]x[4]xf32>
- return %slice : vector<[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @vector_extract_slice_i16
-func.func @vector_extract_slice_i16(%tile: vector<[8]x[8]xi16>, %row: index) -> vector<[8]xi16> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
- %slice = vector.extract %tile[%row] : vector<[8]xi16> from vector<[8]x[8]xi16>
- return %slice : vector<[8]xi16>
+// CHECK-LABEL: @vector_extract_slice_i128
+func.func @vector_extract_slice_i128(%tile: vector<[1]x[1]xi128>, %row: index) -> vector<[1]xi128> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ %slice = vector.extract %tile[%row] : vector<[1]xi128> from vector<[1]x[1]xi128>
+ return %slice : vector<[1]xi128>
}
// -----
@@ -779,6 +773,24 @@ func.func @vector_extract_slice_bf16(%tile: vector<[8]x[8]xbf16>, %row: index) -
// -----
+// CHECK-LABEL: @vector_extract_slice_f32
+func.func @vector_extract_slice_f32(%tile: vector<[4]x[4]xf32>, %row: index) -> vector<[4]xf32> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ %slice = vector.extract %tile[%row] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ return %slice : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_slice_f64
+func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) -> vector<[2]xf64> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ %slice = vector.extract %tile[%row] : vector<[2]xf64> from vector<[2]x[2]xf64>
+ return %slice : vector<[2]xf64>
+}
+
+// -----
+
// CHECK-LABEL: @vector_extract_slice_i8
func.func @vector_extract_slice_i8(%tile: vector<[16]x[16]xi8>, %row: index) -> vector<[16]xi8> {
// CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
@@ -806,12 +818,22 @@ func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col:
// -----
-// CHECK-LABEL: @vector_extract_element_i128
-func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index, %col: index) -> i128 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[1]xi128>
- %el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128>
- return %el : i128
+// CHECK-LABEL: @vector_extract_element_i8
+func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %col: index) -> i8 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[16]xi8>
+ %el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8>
+ return %el : i8
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element_i16
+func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %col: index) -> i16 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xi16>
+ %el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16>
+ return %el : i16
}
// -----
@@ -826,32 +848,12 @@ func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %
// -----
-// CHECK-LABEL: @vector_extract_element_f64
-func.func @vector_extract_element_f64(%tile: vector<[2]x[2]xf64>, %row: index, %col: index) -> f64 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[2]xf64>
- %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
- return %el : f64
-}
-
-// -----
-
-// CHECK-LABEL: @vector_extract_element_f32
-func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %col: index) -> f32 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[4]xf32>
- %el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32>
- return %el : f32
-}
-
-// -----
-
-// CHECK-LABEL: @vector_extract_element_i16
-func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %col: index) -> i16 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xi16>
- %el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16>
- return %el : i16
+// CHECK-LABEL: @vector_extract_element_i128
+func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index, %col: index) -> i128 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[1]xi128>
+ %el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128>
+ return %el : i128
}
// -----
@@ -876,10 +878,20 @@ func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index,
// -----
-// CHECK-LABEL: @vector_extract_element_i8
-func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %col: index) -> i8 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
- // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[16]xi8>
- %el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8>
- return %el : i8
+// CHECK-LABEL: @vector_extract_element_f32
+func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %col: index) -> f32 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[4]xf32>
+ %el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32>
+ return %el : f32
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element_f64
+func.func @vector_extract_element_f64(%tile: vector<[2]x[2]xf64>, %row: index, %col: index) -> f64 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[2]xf64>
+ %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
+ return %el : f64
}
More information about the Mlir-commits mailing list