[Mlir-commits] [mlir] [mlir][ArmSME] Move vector.extract/insert lowerings to vector-to-arm-sme (NFC) (PR #72852)
llvmlistbot at llvm.org
Mon Nov 20 03:21:24 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sme
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
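These patterns were previously placed in LegalizeForLLVMExport.cpp, which is the wrong stage for them: they lower `vector.extract`/`vector.insert` to high-level ArmSME ops (`arm_sme.move_tile_slice_to_vector` and `arm_sme.move_vector_to_tile_slice`), not to intrinsics. This patch moves them into the vector-to-arm-sme conversion (VectorToArmSME.cpp) and turns them from LLVM conversion patterns into plain rewrite patterns.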
---
Full diff: https://github.com/llvm/llvm-project/pull/72852.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+113-1)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+3-115)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 953a465c18de69f..420d2b6b1c08786 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -576,6 +576,116 @@ struct VectorOuterProductToArmSMELowering
}
};
+/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
+///
+/// Example:
+/// ```
+/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
+/// : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
+/// ```
+struct VectorExtractToArmSMELowering
+ : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ if (!arm_sme::isValidSMETileVectorType(sourceType))
+ return failure();
+
+ auto loc = extractOp.getLoc();
+ auto position = extractOp.getMixedPosition();
+
+ Value sourceVector = extractOp.getVector();
+
+ // Extract entire vector. Should be handled by folder, but just to be safe.
+ if (position.empty()) {
+ rewriter.replaceOp(extractOp, sourceVector);
+ return success();
+ }
+
+ Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
+ auto moveTileSliceToVector =
+ rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
+ sliceIndex);
+
+ if (position.size() == 1) {
+ // Single index case: Extracts a 1D slice.
+ rewriter.replaceOp(extractOp, moveTileSliceToVector);
+ return success();
+ }
+
+ // Two indices case: Extracts a single element.
+ assert(position.size() == 2);
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, moveTileSliceToVector, position[1]);
+
+ return success();
+ }
+};
+
+/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
+/// `arm_sme.move_tile_slice_to_vector`.
+///
+/// Example:
+/// ```
+/// %new_tile = vector.insert %el, %tile[%row, %col]
+/// : i32 into vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
+/// : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
+/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
+/// : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// ```
+struct VectorInsertToArmSMELowering
+ : public OpRewritePattern<vector::InsertOp> {
+ using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::InsertOp insertOp,
+ PatternRewriter &rewriter) const override {
+ VectorType resultType = insertOp.getResult().getType();
+
+ if (!arm_sme::isValidSMETileVectorType(resultType))
+ return failure();
+
+ auto loc = insertOp.getLoc();
+ auto position = insertOp.getMixedPosition();
+
+ Value source = insertOp.getSource();
+
+ // Overwrite entire vector with value. Should be handled by folder, but
+ // just to be safe.
+ if (position.empty()) {
+ rewriter.replaceOp(insertOp, source);
+ return success();
+ }
+
+ Value tileSlice = source;
+ Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
+ if (position.size() == 2) {
+ // Two indices case: Insert single element into tile.
+ // We need to first extract the existing slice and update the element.
+ tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+ loc, insertOp.getDest(), sliceIndex);
+ tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
+ position[1]);
+ }
+
+ // Insert the slice into the destination tile.
+ rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
+ insertOp, tileSlice, insertOp.getDest(), sliceIndex);
+ return success();
+ }
+};
+
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
@@ -584,5 +694,7 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
SplatOpToArmSMELowering, TransferReadToArmSMELowering,
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
- VectorOuterProductToArmSMELowering>(&ctx);
+ VectorOuterProductToArmSMELowering,
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
+ &ctx);
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 6078b3f2c5e4708..041a4897a836503 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -456,7 +456,8 @@ struct OuterProductOpConversion
// * half-precision - +sme2p1,+b16b16
//
// It should be possible to control lowering based on target features.
- // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
+ // [1]
+ // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
return false;
@@ -520,118 +521,6 @@ struct OuterProductOpConversion
}
};
-/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
-///
-/// Example:
-/// ```
-/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
-/// ```
-/// Becomes:
-/// ```
-/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
-/// : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
-/// ```
-struct VectorExtractToArmSMELowering
- : public ConvertOpToLLVMPattern<vector::ExtractOp> {
- using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- VectorType sourceType = extractOp.getSourceVectorType();
- if (!isValidSMETileVectorType(sourceType))
- return failure();
-
- auto loc = extractOp.getLoc();
- auto position = extractOp.getMixedPosition();
-
- Value sourceVector = extractOp.getVector();
-
- // Extract entire vector. Should be handled by folder, but just to be safe.
- if (position.empty()) {
- rewriter.replaceOp(extractOp, sourceVector);
- return success();
- }
-
- Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
- auto moveTileSliceToVector =
- rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
- sliceIndex);
-
- if (position.size() == 1) {
- // Single index case: Extracts a 1D slice.
- rewriter.replaceOp(extractOp, moveTileSliceToVector);
- return success();
- }
-
- // Two indices case: Extracts a single element.
- assert(position.size() == 2);
- rewriter.replaceOpWithNewOp<vector::ExtractOp>(
- extractOp, moveTileSliceToVector, position[1]);
-
- return success();
- }
-};
-
-/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
-/// `arm_sme.move_tile_slice_to_vector`.
-///
-/// Example:
-/// ```
-/// %new_tile = vector.insert %el, %tile[%row, %col]
-/// : i32 into vector<[4]x[4]xi32>
-/// ```
-/// Becomes:
-/// ```
-/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
-/// : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
-/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
-/// : vector<[4]xi32> into vector<[4]x[4]xi32>
-/// ```
-struct VectorInsertToArmSMELowering
- : public ConvertOpToLLVMPattern<vector::InsertOp> {
- using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- VectorType resultType = insertOp.getResult().getType();
-
- if (!isValidSMETileVectorType(resultType))
- return failure();
-
- auto loc = insertOp.getLoc();
- auto position = insertOp.getMixedPosition();
-
- Value source = adaptor.getSource();
-
- // Overwrite entire vector with value. Should be handled by folder, but
- // just to be safe.
- if (position.empty()) {
- rewriter.replaceOp(insertOp, source);
- return success();
- }
-
- Value tileSlice = source;
- Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
- if (position.size() == 2) {
- // Two indices case: Insert single element into tile.
- // We need to first extract the existing slice and update the element.
- tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
- loc, adaptor.getDest(), sliceIndex);
- tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
- position[1]);
- }
-
- // Insert the slice into the destination tile.
- rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
- insertOp, tileSlice, adaptor.getDest(), sliceIndex);
- return success();
- }
-};
-
} // namespace
void mlir::configureArmSMELegalizeForExportTarget(
@@ -661,6 +550,5 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
patterns.add<
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
- OuterProductOpConversion, ZeroOpConversion, VectorExtractToArmSMELowering,
- VectorInsertToArmSMELowering>(converter);
+ OuterProductOpConversion, ZeroOpConversion>(converter);
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/72852
More information about the Mlir-commits mailing list