[Mlir-commits] [mlir] [mlir][ArmSME] Move vector.extract/insert lowerings to vector-to-arm-sme (NFC) (PR #72852)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 20 03:21:24 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

These patterns were placed in LegalizeForLLVMExport.cpp, which is the wrong stage for them: they lower to high-level ArmSME ops, not to intrinsics.

---
Full diff: https://github.com/llvm/llvm-project/pull/72852.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+113-1) 
- (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+3-115) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 953a465c18de69f..420d2b6b1c08786 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -576,6 +576,116 @@ struct VectorOuterProductToArmSMELowering
   }
 };
 
+/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
+///
+/// Example:
+/// ```
+/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
+///            : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
+/// ```
+struct VectorExtractToArmSMELowering
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    if (!arm_sme::isValidSMETileVectorType(sourceType))
+      return failure();
+
+    auto loc = extractOp.getLoc();
+    auto position = extractOp.getMixedPosition();
+
+    Value sourceVector = extractOp.getVector();
+
+    // Extract entire vector. Should be handled by folder, but just to be safe.
+    if (position.empty()) {
+      rewriter.replaceOp(extractOp, sourceVector);
+      return success();
+    }
+
+    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
+    auto moveTileSliceToVector =
+        rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
+                                                          sliceIndex);
+
+    if (position.size() == 1) {
+      // Single index case: Extracts a 1D slice.
+      rewriter.replaceOp(extractOp, moveTileSliceToVector);
+      return success();
+    }
+
+    // Two indices case: Extracts a single element.
+    assert(position.size() == 2);
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+        extractOp, moveTileSliceToVector, position[1]);
+
+    return success();
+  }
+};
+
+/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
+/// `arm_sme.move_tile_slice_to_vector`.
+///
+/// Example:
+/// ```
+/// %new_tile = vector.insert %el, %tile[%row, %col]
+///                     : i32 into vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
+///            : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
+/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
+///               : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// ```
+struct VectorInsertToArmSMELowering
+    : public OpRewritePattern<vector::InsertOp> {
+  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = insertOp.getResult().getType();
+
+    if (!arm_sme::isValidSMETileVectorType(resultType))
+      return failure();
+
+    auto loc = insertOp.getLoc();
+    auto position = insertOp.getMixedPosition();
+
+    Value source = insertOp.getSource();
+
+    // Overwrite entire vector with value. Should be handled by folder, but
+    // just to be safe.
+    if (position.empty()) {
+      rewriter.replaceOp(insertOp, source);
+      return success();
+    }
+
+    Value tileSlice = source;
+    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
+    if (position.size() == 2) {
+      // Two indices case: Insert single element into tile.
+      // We need to first extract the existing slice and update the element.
+      tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+          loc, insertOp.getDest(), sliceIndex);
+      tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
+                                                    position[1]);
+    }
+
+    // Insert the slice into the destination tile.
+    rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
+        insertOp, tileSlice, insertOp.getDest(), sliceIndex);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
@@ -584,5 +694,7 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                SplatOpToArmSMELowering, TransferReadToArmSMELowering,
                TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
                VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
-               VectorOuterProductToArmSMELowering>(&ctx);
+               VectorOuterProductToArmSMELowering,
+               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
+      &ctx);
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 6078b3f2c5e4708..041a4897a836503 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -456,7 +456,8 @@ struct OuterProductOpConversion
       //   * half-precision   - +sme2p1,+b16b16
       //
       // It should be possible to control lowering based on target features.
-      // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
+      // [1]
+      // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
       if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
         return false;
 
@@ -520,118 +521,6 @@ struct OuterProductOpConversion
   }
 };
 
-/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
-///
-/// Example:
-/// ```
-/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
-/// ```
-/// Becomes:
-/// ```
-/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
-///            : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
-/// ```
-struct VectorExtractToArmSMELowering
-    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
-  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    VectorType sourceType = extractOp.getSourceVectorType();
-    if (!isValidSMETileVectorType(sourceType))
-      return failure();
-
-    auto loc = extractOp.getLoc();
-    auto position = extractOp.getMixedPosition();
-
-    Value sourceVector = extractOp.getVector();
-
-    // Extract entire vector. Should be handled by folder, but just to be safe.
-    if (position.empty()) {
-      rewriter.replaceOp(extractOp, sourceVector);
-      return success();
-    }
-
-    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
-    auto moveTileSliceToVector =
-        rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
-                                                          sliceIndex);
-
-    if (position.size() == 1) {
-      // Single index case: Extracts a 1D slice.
-      rewriter.replaceOp(extractOp, moveTileSliceToVector);
-      return success();
-    }
-
-    // Two indices case: Extracts a single element.
-    assert(position.size() == 2);
-    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
-        extractOp, moveTileSliceToVector, position[1]);
-
-    return success();
-  }
-};
-
-/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
-/// `arm_sme.move_tile_slice_to_vector`.
-///
-/// Example:
-/// ```
-/// %new_tile = vector.insert %el, %tile[%row, %col]
-///                     : i32 into vector<[4]x[4]xi32>
-/// ```
-/// Becomes:
-/// ```
-/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
-///            : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
-/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
-///               : vector<[4]xi32> into vector<[4]x[4]xi32>
-/// ```
-struct VectorInsertToArmSMELowering
-    : public ConvertOpToLLVMPattern<vector::InsertOp> {
-  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = insertOp.getResult().getType();
-
-    if (!isValidSMETileVectorType(resultType))
-      return failure();
-
-    auto loc = insertOp.getLoc();
-    auto position = insertOp.getMixedPosition();
-
-    Value source = adaptor.getSource();
-
-    // Overwrite entire vector with value. Should be handled by folder, but
-    // just to be safe.
-    if (position.empty()) {
-      rewriter.replaceOp(insertOp, source);
-      return success();
-    }
-
-    Value tileSlice = source;
-    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
-    if (position.size() == 2) {
-      // Two indices case: Insert single element into tile.
-      // We need to first extract the existing slice and update the element.
-      tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
-          loc, adaptor.getDest(), sliceIndex);
-      tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
-                                                    position[1]);
-    }
-
-    // Insert the slice into the destination tile.
-    rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
-        insertOp, tileSlice, adaptor.getDest(), sliceIndex);
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::configureArmSMELegalizeForExportTarget(
@@ -661,6 +550,5 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
   patterns.add<
       LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
       MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
-      OuterProductOpConversion, ZeroOpConversion, VectorExtractToArmSMELowering,
-      VectorInsertToArmSMELowering>(converter);
+      OuterProductOpConversion, ZeroOpConversion>(converter);
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/72852


More information about the Mlir-commits mailing list