[Mlir-commits] [mlir] [mlir][ArmSME] Add tile slice layout attr to vector <-> tile ops (PR #69186)

Cullen Rhodes llvmlistbot at llvm.org
Mon Oct 16 03:09:51 PDT 2023


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/69186

This is used in #69148 when lowering a masked tile_store with a non-zero pad; see https://github.com/llvm/llvm-project/pull/69148/commits/8589e503f836de48356cc4bbed105d661767f148

>From 497ff8469720dcce882c6e3c67ca689ff0f96273 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 10:19:34 +0000
Subject: [PATCH] [mlir][ArmSME] Add tile slice layout attr to vector <-> tile
 ops

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 36 ++++++++++------
 .../Transforms/LegalizeForLLVMExport.cpp      | 43 +++++++++++++------
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 32 ++++++++++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 16 +++++++
 4 files changed, 100 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index dab54b63d8d22be..9b9dbff10ea2da6 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -441,21 +441,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
     of a 2-D scalable vector tile at the given index. The type of the 1-D
     scalable vector to be moved must match the type of the tile slice. A tile
     slice is a 1-D vector of horizontally or vertically contiguous elements
-    within a ZA tile. Horizontal tile slices are currently assumed when
-    lowering to intrinsics. The updated tile is returned as the result.
+    within a ZA tile. The updated tile is returned as the result.
 
-    Example 1: Move a vector<[16]xi8> into tile at given index.
+    An optional tile slice layout attribute specifies whether the tile slice is
+    horizontal (default) or vertical.
+
+    Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index.
     ```mlir
     %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
     ```
 
-    Example 2: Move a vector<[2]xf64> into tile at given index.
+    Example 2: Move a vector<[2]xf64> into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
+    %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
     ```
   }];
   let arguments = (ins
-      SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index);
+      SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index,
+      ArmSME_TileSliceLayoutAttr:$layout);
   let results = (outs SMETile:$result);
 
   let extraClassDeclaration = [{
@@ -465,7 +468,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $vector `,` $tile `,` $tile_slice_index
+    $vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
       attr-dict `:` type($vector) `into` type($result)
   }];
 }
@@ -480,21 +483,26 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   let description = [{
     The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
     scalable tile at the given index. A tile slice is a 1-D vector of
-    horizontally or vertically contiguous elements within a ZA tile. Horizontal
-    tile slices are currently assumed when lowering to intrinsics.
+    horizontally or vertically contiguous elements within a ZA tile.
+
+    An optional tile slice layout attribute specifies whether the tile slice is
+    horizontal (default) or vertical.
 
-    Example 1: Extract `vector<[16]xi8>` from tile at the given index.
+    Example 1: Extract `vector<[16]xi8>` from tile horizontally (default) at the given index.
     ```mlir
     %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
     ```
 
-    Example 2: Extract `vector<[2]xf64>` from tile at the given index.
+    Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index.
     ```mlir
-    %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
+    %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
     ```
   }];
 
-  let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
+  let arguments = (ins
+    SMETile:$tile, Index:$tile_slice_index,
+    ArmSME_TileSliceLayoutAttr:$layout
+  );
   let results = (outs SVEVector:$result);
 
   let extraClassDeclaration = [{
@@ -502,7 +510,7 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 
   let assemblyFormat = [{
-      $tile `[` $tile_slice_index `]` attr-dict
+      $tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict
       `:` type($result) `from` type($tile)
   }];
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 5e13707ea0aa2b9..1231da356f8ed95 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -350,8 +350,7 @@ struct StoreTileSliceToArmSMELowering
   }
 };
 
-/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal
-/// tile slices are currently supported.
+/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
 struct MoveVectorToTileSliceToArmSMELowering
     : public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
   using ConvertOpToLLVMPattern<
@@ -388,10 +387,19 @@ struct MoveVectorToTileSliceToArmSMELowering
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
 
-    // Create 'arm_sme.intr.write.horiz' to write vector to tile slice.
-    rewriter.create<arm_sme::aarch64_sme_write_horiz>(
-        loc, tileI32, tileSliceI32, allActiveMask,
-        moveVectorToTileSliceOp.getVector());
+    // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
+    switch (moveVectorToTileSliceOp.getLayout()) {
+    case arm_sme::TileSliceLayout::Horizontal:
+      rewriter.create<arm_sme::aarch64_sme_write_horiz>(
+          loc, tileI32, tileSliceI32, allActiveMask,
+          moveVectorToTileSliceOp.getVector());
+      break;
+    case arm_sme::TileSliceLayout::Vertical:
+      rewriter.create<arm_sme::aarch64_sme_write_vert>(
+          loc, tileI32, tileSliceI32, allActiveMask,
+          moveVectorToTileSliceOp.getVector());
+      break;
+    }
 
     // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
     // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
@@ -402,8 +410,7 @@ struct MoveVectorToTileSliceToArmSMELowering
   }
 };
 
-/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal
-/// tile slices are currently supported.
+/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
 struct MoveTileSliceToVectorArmSMELowering
     : public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
   using ConvertOpToLLVMPattern<
@@ -435,10 +442,19 @@ struct MoveTileSliceToVectorArmSMELowering
     auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
         loc, rewriter.getI32Type(), sliceIndex);
 
-    // Create 'arm_sme.intr.read.horiz' to extract the tile slice.
-    rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
-        moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
-        tileIdI32, sliceIndexI32);
+    // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
+    switch (moveTileSliceToVector.getLayout()) {
+    case arm_sme::TileSliceLayout::Horizontal:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
+          moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+          tileIdI32, sliceIndexI32);
+      break;
+    case arm_sme::TileSliceLayout::Vertical:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
+          moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+          tileIdI32, sliceIndexI32);
+      break;
+    }
 
     return success();
   }
@@ -680,7 +696,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
       arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
       arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
-      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa,
+      arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
       arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
   target.addLegalOp<GetTileID>();
   target.addIllegalOp<vector::OuterProductOp>();
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 07485b3ee8ddf86..9074f0a7ee655c1 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -400,6 +400,29 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
+// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
+// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+  return
+}
 
 //===----------------------------------------------------------------------===//
 // arm_sme.move_tile_slice_to_vector
@@ -485,3 +508,12 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
+// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
+  return %slice : vector<[1]xi128>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 427154158e797fd..e5ba81eff836027 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1059,6 +1059,14 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til
   return
 }
 
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_ver_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
+  // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
+  return
+}
 
 //===----------------------------------------------------------------------===//
 // arm_sme.move_tile_slice_to_vector
@@ -1135,3 +1143,11 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
+  return %slice : vector<[2]xf64>
+}



More information about the Mlir-commits mailing list