[Mlir-commits] [mlir] [mlir][ArmSME] Lower vector.extract/insert on SME tiles to MOVA intrinsics (PR #67786)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 29 04:06:34 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sme
<details>
<summary>Changes</summary>
This patch adds support for lowering `vector.insert`/`vector.extract` of tile slices or elements to the ArmSME MOVA intrinsics.
This enables the following operations for ArmSME:
```
// Extract slice from tile:
%slice = vector.extract %tile[%y] : vector<[4]xi32> from vector<[4]x[4]xi32>
```
```
// Extract element from tile:
%el = vector.extract %tile[%y, %x] : i32 from vector<[4]x[4]xi32>
```
```
// Insert slice into tile:
%new_tile = vector.insert %slice, %tile[%y]
: vector<[4]xi32> into vector<[4]x[4]xi32>
```
```
// Insert element into tile:
%new_tile = vector.insert %el, %tile[%y, %x]
: i32 into vector<[4]x[4]xi32>
```
---
Full diff: https://github.com/llvm/llvm-project/pull/67786.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+112-5)
- (modified) mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir (+78)
``````````diff
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 0322c2f3fcd14d4..edf9e333b0e4784 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -550,6 +550,113 @@ struct VectorOuterProductToArmSMELowering
}
};
+/// Lower `vector.extract` from an SME tile using SME MOVA intrinsics.
+///
+/// Example:
+/// ```
+/// %el = vector.extract %tile[%y, %x] : i32 from vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%y]
+///            : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %el = vector.extract %slice[%x] : i32 from vector<[4]xi32>
+/// ```
+struct VectorExtractToArmSMELowering
+    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
+  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    if (!isValidSMETileVectorType(sourceType))
+      return failure();
+
+    auto loc = extractOp.getLoc();
+    auto position = extractOp.getMixedPosition();
+    // Take the tile from the adaptor (remapped operand), as insert does.
+    Value sourceVector = adaptor.getVector();
+    // No indices: the result is the whole tile (likely folded away earlier).
+    if (position.empty()) {
+      rewriter.replaceOp(extractOp, sourceVector);
+      return success();
+    }
+
+    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
+    auto moveTileSliceToVector =
+        rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
+                                                          sliceIndex);
+
+    if (position.size() == 1) {
+      // Single index case: Extracts a 1D slice.
+      rewriter.replaceOp(extractOp, moveTileSliceToVector);
+      return success();
+    }
+
+    // Two indices case: Extracts a single element from the extracted slice.
+    assert(position.size() == 2 && "expected 2 indices for an element");
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+        extractOp, moveTileSliceToVector, position[1]);
+
+    return success();
+  }
+};
+
+/// Lower `vector.insert` into an SME tile using SME MOVA intrinsics.
+///
+/// Example:
+/// ```
+/// %new_tile = vector.insert %el, %tile[%y, %x] : i32 into vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%y]
+///            : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %new_slice = vector.insert %el, %slice[%x] : i32 into vector<[4]xi32>
+/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %y
+///               : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// ```
+struct VectorInsertToArmSMELowering
+    : public ConvertOpToLLVMPattern<vector::InsertOp> {
+  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = insertOp.getResult().getType();
+
+    if (!isValidSMETileVectorType(resultType))
+      return failure();
+
+    auto loc = insertOp.getLoc();
+    auto position = insertOp.getMixedPosition();
+
+    Value source = adaptor.getSource();
+    // No indices: the source replaces the whole tile (likely folded earlier).
+    if (position.empty()) {
+      rewriter.replaceOp(insertOp, source);
+      return success();
+    }
+
+    // One index: `source` is already the full slice to write into the tile.
+    Value tileSlice = source;
+    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
+    if (position.size() == 2) {
+      // Two indices case: Insert single element into tile.
+      // We need to first extract the existing slice and update the element.
+      tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+          loc, adaptor.getDest(), sliceIndex);
+      tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
+                                                    position[1]);
+    }
+    // Insert the slice into the destination tile.
+    rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
+        insertOp, tileSlice, adaptor.getDest(), sliceIndex);
+    return success();
+  }
+};
+
} // namespace
void mlir::configureArmSMELegalizeForExportTarget(
@@ -601,9 +708,9 @@ void mlir::configureArmSMELegalizeForExportTarget(
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
- patterns
- .add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
- LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
- MoveVectorToTileSliceToArmSMELowering,
- VectorOuterProductToArmSMELowering>(converter);
+ patterns.add< // Patterns constructed with the LLVM type converter.
+ ZeroOpConversion, StoreTileSliceToArmSMELowering,
+ LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
+ MoveVectorToTileSliceToArmSMELowering, VectorOuterProductToArmSMELowering,
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 687ef79385334cf..9678a0c91c38a32 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -496,3 +496,81 @@ func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vec
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}
+
+//===----------------------------------------------------------------------===//
+// vector.insert
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @vector_insert_slice(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[SLICE:.*]]: vector<[4]xi32>,
+// CHECK-SAME: %[[Y:.*]]: index)
+func.func @vector_insert_slice(%tile: vector<[4]x[4]xi32>, %slice: vector<[4]xi32>, %y: index) -> vector<[4]x[4]xi32> {
+  // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[Y_I32:.*]] = arith.index_castui %[[Y]] : index to i32
+  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[Y_I32]], %[[PTRUE]], %[[SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  %new_tile = vector.insert %slice, %tile[%y] : vector<[4]xi32> into vector<[4]x[4]xi32>
+  return %new_tile : vector<[4]x[4]xi32>
+}
+
+// -----
+
+
+// CHECK-LABEL: @vector_insert_element(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[EL:.*]]: i32,
+// CHECK-SAME: %[[Y:.*]]: index,
+// CHECK-SAME: %[[X:.*]]: index)
+func.func @vector_insert_element(%tile: vector<[4]x[4]xi32>, %el: i32, %y: index, %x: index) -> vector<[4]x[4]xi32> {
+  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+  // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+  // CHECK-NEXT: %[[X_I64:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i64
+  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[Y_I32:.*]] = arith.index_cast %[[Y]] : index to i32
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[Y_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  // CHECK-NEXT: %[[NEW_SLICE:.*]] = llvm.insertelement %[[EL]], %[[SLICE]]{{\[}}%[[X_I64]] : i64] : vector<[4]xi32>
+  // CHECK-NEXT: %[[SLICE_INDEX:.*]] = arith.index_castui %[[Y]] : index to i32
+  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  %new_tile = vector.insert %el, %tile[%y, %x] : i32 into vector<[4]x[4]xi32>
+  return %new_tile : vector<[4]x[4]xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// vector.extract
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @vector_extract_slice(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[Y:.*]]: index)
+func.func @vector_extract_slice(%tile: vector<[4]x[4]xi32>, %y: index) -> vector<[4]xi32> {
+  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+  // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[Y_I32:.*]] = arith.index_cast %[[Y]] : index to i32
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[Y_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  %slice = vector.extract %tile[%y] : vector<[4]xi32> from vector<[4]x[4]xi32>
+  return %slice : vector<[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[Y:.*]]: index,
+// CHECK-SAME: %[[X:.*]]: index)
+func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %y: index, %x: index) -> i32 {
+  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+  // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+  // CHECK-NEXT: %[[X_I64:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i64
+  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[Y_I32:.*]] = arith.index_cast %[[Y]] : index to i32
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[Y_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  // CHECK-NEXT: %[[EL:.*]] = llvm.extractelement %[[SLICE]]{{\[}}%[[X_I64]] : i64] : vector<[4]xi32>
+  %el = vector.extract %tile[%y, %x] : i32 from vector<[4]x[4]xi32>
+  return %el : i32
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/67786
More information about the Mlir-commits
mailing list