[Mlir-commits] [mlir] [mlir][ArmSME] Add arm_sme.move_tile_slice_to_vector op (PR #67652)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Sep 28 05:18:00 PDT 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/67652

>From d09d976e08280ae6d6e3cfac0fba74f3730df946 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 28 Sep 2023 09:17:35 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Add arm_sme.move_tile_slice_to_vector op

This adds a simple higher-level op for the tile slice to vector
intrinsics (and updates the existing vector.print lowering to use it).
This op will be used a few more times to implement vector.insert/extract
lowerings in later patches.
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 33 +++++++
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    | 32 +------
 .../Transforms/LegalizeForLLVMExport.cpp      | 51 ++++++++++-
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           | 16 ++--
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 86 +++++++++++++++++++
 mlir/test/Dialect/ArmSME/invalid.mlir         |  8 ++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 77 +++++++++++++++++
 7 files changed, 261 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index d10cee5956d5e5f..10b2ae32b06dd37 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -501,6 +501,39 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
   }];
 }
 
+def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
+    TypesMatchWith<
+      "type of 'result' matches type of 'tile' slice",
+      "tile", "result",
+      "VectorType(VectorType::Builder(::llvm::cast<mlir::VectorType>($_self)).dropDim(0))">,
+]> {
+  let summary = "Move slice of a 2-D tile to a 1-D scalable vector";
+  let description = [{
+    The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
+    scalable tile at the given index. A tile slice is a 1-D vector of
+    horizontally or vertically contiguous elements within a ZA tile. Horizontal
+    tile slices are currently assumed when lowering to intrinsics.
+
+    Example 1: Extract `vector<[16]xi8>` from tile at the given index.
+    ```mlir
+    %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
+    ```
+
+    Example 2: Extract `vector<[2]xf64>` from tile at the given index.
+    ```mlir
+    %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
+    ```
+  }];
+
+  let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
+  let results = (outs SVEVector:$result);
+
+  let assemblyFormat = [{
+      $tile `[` $tile_slice_index `]` attr-dict
+      `:` type($result) `from` type($tile)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME Intrinsic op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index d084373439ab6bf..ec625a6806e2bee 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -202,16 +202,11 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
 ///  %c0 = arith.constant 0 : index
 ///  %c1 = arith.constant 1 : index
 ///  %c4 = arith.constant 4 : index
-///  %ptrue = arith.constant dense<true> : vector<[4]xi1>
-///  %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xf32> to i32
 ///  %vscale = vector.vscale
 ///  %svl_s = arith.muli %c4, %vscale : index
-///  %cst = arith.constant dense<0.000000e+00> : vector<[4]xf32>
 ///  scf.for %i = %c0 to %svl_s step %c1 {
-///    %slice_idx = arith.index_cast %i : index to i32
-///    %tile_slice = "arm_sme.intr.read.horiz"
-///        (%cst, %ptrue, %tile_id, %slice_idx)
-///      : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+///    %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
+///                     : vector<[4]xf32> from vector<[4]x[4]xf32>
 ///    vector.print %tile_slice : vector<[4]xf32>
 ///  }
 ///  ```
@@ -229,23 +224,6 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
 
     auto loc = printOp.getLoc();
 
-    // Create an 'all true' predicate for each tile row.
-    auto predicateType =
-        VectorType::get(vectorType.getDimSize(1), rewriter.getI1Type(), true);
-    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
-        loc, DenseElementsAttr::get(predicateType, true));
-
-    // Cast tile to i32 tile ID.
-    auto tileId =
-        rewriter.create<arm_sme::CastVectorToTile>(loc, printOp.getSource());
-    Value tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
-
-    // Zero destination/fallback for tile slice extraction.
-    auto rowType = VectorType::get(vectorType.getDimSize(1),
-                                   vectorType.getElementType(), true);
-    auto zeroVector = rewriter.create<arith::ConstantOp>(
-        loc, rowType, rewriter.getZeroAttr(rowType));
-
     // Create a loop over the rows of the tile.
     auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
     auto minTileRows =
@@ -259,10 +237,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
       rewriter.setInsertionPointToStart(forOp.getBody());
       // Extract the current row from the tile.
       Value rowIndex = forOp.getInductionVar();
-      auto rowIndexI32 = rewriter.create<arith::IndexCastOp>(
-          loc, rewriter.getI32Type(), rowIndex);
-      auto tileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
-          loc, rowType, zeroVector, allTruePredicate, tileIdI32, rowIndexI32);
+      auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+          loc, printOp.getSource(), rowIndex);
       // Print the row with a 1D vector.print.
       rewriter.create<vector::PrintOp>(loc, tileSlice,
                                        printOp.getPunctuation());
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 945716eea57543d..7157a915ee8a02a 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -402,6 +402,48 @@ struct MoveVectorToTileSliceToArmSMELowering
   }
 };
 
+/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal
+/// tile slices are currently supported.
+struct MoveTileSliceToVectorArmSMELowering
+    : public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
+  using ConvertOpToLLVMPattern<
+      arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector,
+                  OpAdaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = moveTileSliceToVector.getLoc();
+    auto sliceType = moveTileSliceToVector.getResult().getType();
+    auto tile = moveTileSliceToVector.getTile();
+    auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
+
+    // Cast tile to i32 tile ID.
+    auto tileId = rewriter.create<arm_sme::CastVectorToTile>(loc, tile);
+    Value tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
+
+    // Create an 'all true' predicate for the tile slice.
+    auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
+    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+
+    // Zero destination/fallback for tile slice extraction.
+    auto zeroVector = rewriter.create<arith::ConstantOp>(
+        loc, sliceType, rewriter.getZeroAttr(sliceType));
+
+    // Cast tile slice from index to i32 for intrinsic.
+    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI32Type(), sliceIndex);
+
+    // Create 'arm_sme.intr.read.horiz' to extract the tile slice.
+    rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
+        moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+        tileIdI32, sliceIndexI32);
+
+    return success();
+  }
+};
+
 /// Lower `vector.outerproduct` to SME MOPA intrinsics.
 ///
 /// Example:
@@ -525,9 +567,9 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
       arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
       arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
-      arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
-      arm_sme::aarch64_sme_za_disable>();
+      arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
+      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa,
+      arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
   target.addLegalOp<GetTileID>();
   target.addIllegalOp<vector::OuterProductOp>();
 
@@ -561,6 +603,7 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
   patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
   patterns
       .add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
-           LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering,
+           LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
+           MoveVectorToTileSliceToArmSMELowering,
            VectorOuterProductToArmSMELowering>(converter);
 }
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index b287a00171d9bd5..09f148bcd42f593 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -66,15 +66,11 @@ func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
 }
 // CHECK-LABEL:   func.func @arm_sme_tile_print(
 // CHECK-SAME:                                  %[[TILE:.*]]: vector<[4]x[4]xf32>) {
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[C4:.*]] = arith.constant 4 : index
-// CHECK-DAG:       %[[VSCALE:.*]] = vector.vscale
-// CHECK-DAG:       %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-// CHECK-DAG:       %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
-// CHECK-DAG:       %[[ZERO_VECTOR:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
-// CHECK-DAG:       %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK-NEXT:      scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK-NEXT:        %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_cast %[[TILE_SLICE_INDEX]] : index to i32
-// CHECK-NEXT:        %[[TILE_SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VECTOR]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+// CHECK-NEXT:        %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
 // CHECK-NEXT:        vector.print %[[TILE_SLICE]] : vector<[4]xf32>
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index a6b5217181481de..4c16e5c488a74cd 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -399,3 +399,89 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
   arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
+
+
+//===----------------------------------------------------------------------===//
+// arm_sme.move_tile_slice_to_vector
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i8
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]xi8> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
+  return %slice : vector<[16]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i16
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index) -> vector<[8]xi16> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xi16> from vector<[8]x[8]xi16>
+  return %slice : vector<[8]xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i32
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> vector<[4]xi32> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xi32> from vector<[4]x[4]xi32>
+  return %slice : vector<[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i64
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index) -> vector<[2]xi64> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xi64> from vector<[2]x[2]xi64>
+  return %slice : vector<[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i128
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[1]xi128> from vector<[1]x[1]xi128>
+  return %slice : vector<[1]xi128>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f16
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index) -> vector<[8]xf16> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  return %slice : vector<[8]xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_bf16
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> vector<[8]xbf16> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
+  return %slice : vector<[8]xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f32
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf32> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  return %slice : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f64
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
+  return %slice : vector<[2]xf64>
+}
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index f529b47cb98c3d9..431009b1b9ede2f 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -89,3 +89,11 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect
   %0 = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xf32> into vector<[4]x[4]xf32>
   return %0 : vector<[4]x[4]xf32>
 }
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
+  // expected-error@+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
+  %0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
+  return %0 : vector<[2]xf64>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 3232bad7996b486..f6d19359b8e3af8 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1058,3 +1058,80 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til
   arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
   return
 }
+
+
+//===----------------------------------------------------------------------===//
+// arm_sme.move_tile_slice_to_vector
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]xi8> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
+  return %slice : vector<[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index) -> vector<[8]xi16> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xi16> from vector<[8]x[8]xi16>
+  return %slice : vector<[8]xi16>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> vector<[4]xi32> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[4]xi32> from vector<[4]x[4]xi32>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xi32> from vector<[4]x[4]xi32>
+  return %slice : vector<[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index) -> vector<[2]xi64> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xi64> from vector<[2]x[2]xi64>
+  return %slice : vector<[2]xi64>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[1]xi128> from vector<[1]x[1]xi128>
+  return %slice : vector<[1]xi128>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index) -> vector<[8]xf16> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  return %slice : vector<[8]xf16>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> vector<[8]xbf16> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
+  return %slice : vector<[8]xbf16>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf32> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  return %slice : vector<[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
+  return %slice : vector<[2]xf64>
+}

>From 097ffbfb6d999b6e24fff6d7053e3148357ab109 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 28 Sep 2023 12:16:41 +0000
Subject: [PATCH 2/2] Fixup: Add getSliceType()

---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td                | 4 ++++
 mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 10b2ae32b06dd37..4fd92dbbbf3391a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -528,6 +528,10 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
   let results = (outs SVEVector:$result);
 
+  let extraClassDeclaration = [{
+    VectorType getSliceType() { return getResult().getType(); }
+  }];
+
   let assemblyFormat = [{
       $tile `[` $tile_slice_index `]` attr-dict
       `:` type($result) `from` type($tile)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 7157a915ee8a02a..0322c2f3fcd14d4 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -414,7 +414,7 @@ struct MoveTileSliceToVectorArmSMELowering
                   OpAdaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = moveTileSliceToVector.getLoc();
-    auto sliceType = moveTileSliceToVector.getResult().getType();
+    auto sliceType = moveTileSliceToVector.getSliceType();
     auto tile = moveTileSliceToVector.getTile();
     auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
 



More information about the Mlir-commits mailing list