[Mlir-commits] [mlir] 97da414 - [mlir][ArmSME] Lower loads/stores of (.Q) 128-bit tiles to intrinsics

Benjamin Maxwell llvmlistbot at llvm.org
Wed Aug 23 02:17:13 PDT 2023


Author: Benjamin Maxwell
Date: 2023-08-23T09:16:20Z
New Revision: 97da41418226bb426ae67b1e5b9c0ea3255042f2

URL: https://github.com/llvm/llvm-project/commit/97da41418226bb426ae67b1e5b9c0ea3255042f2
DIFF: https://github.com/llvm/llvm-project/commit/97da41418226bb426ae67b1e5b9c0ea3255042f2.diff

LOG: [mlir][ArmSME] Lower loads/stores of (.Q) 128-bit tiles to intrinsics

This follows from D155306.

Loads and stores of 128-bit tiles have been confirmed to work in the
`load-store-128-bit-tile.mlir` integration test. However, a bug in QEMU
(see: https://gitlab.com/qemu-project/qemu/-/issues/1833) currently causes
this test to produce incorrect results (a patch for this issue is available
but not yet in any released version of QEMU). Until a fixed version of QEMU
is available, the integration test is expected to fail.

Reviewed By: c-rhodes, awarzynski

Differential Revision: https://reviews.llvm.org/D158418

Added: 
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir

Modified: 
    mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
    mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
    mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 2f4dee7ba916e8..e846b63b011a62 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -223,6 +223,10 @@ struct LoadTileSliceToArmSMELowering
       rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, allActiveMask, ptr,
                                                        tileI32, tileSliceI32);
       break;
+    case 128:
+      rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, allActiveMask, ptr,
+                                                       tileI32, tileSliceI32);
+      break;
     }
 
     // The load intrinsics have no result, replace 'arm_sme.tile_load' with
@@ -294,6 +298,10 @@ struct StoreTileSliceToArmSMELowering
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
           storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
+    case 128:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
+          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      break;
     }
 
     return success();
@@ -309,9 +317,10 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
       arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
       arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
-      arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_st1b_horiz,
-      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
-      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_za_enable,
+      arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
+      arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
+      arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
+      arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_za_enable,
       arm_sme::aarch64_sme_za_disable>();
   target.addLegalOp<GetTileID>();
 

diff  --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index a5908a5a8f330f..cc9e36cfca4193 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -25,17 +25,15 @@ unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
 }
 
 bool mlir::arm_sme::isValidSMETileElementType(Type type) {
-  // TODO: add support for i128.
   return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
-         type.isInteger(64) || type.isF16() || type.isBF16() || type.isF32() ||
-         type.isF64();
+         type.isInteger(64) || type.isInteger(128) || type.isF16() ||
+         type.isBF16() || type.isF32() || type.isF64() || type.isF128();
 }
 
 bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
   if ((vType.getRank() != 2) && vType.allDimsScalable())
     return false;
 
-  // TODO: add support for i128.
   auto elemType = vType.getElementType();
   if (!isValidSMETileElementType(elemType))
     return false;

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index de8bc5f93b2c76..af528295ef6ee2 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -220,6 +220,20 @@ func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
 
 // -----
 
+// CHECK-LABEL: @vector_load_i128(
+// CHECK-SAME:                    %[[ARG0:.*]]: memref<?x?xi128>)
+// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i128
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i128 to vector<[1]x[1]xi128>
+// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i128 to i32
+// CHECK:       arm_sme.intr.ld1q.horiz
+func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
+  %c0 = arith.constant 0 : index
+  %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return %tile : vector<[1]x[1]xi128>
+}
+
+// -----
+
 // CHECK-LABEL: @vector_store_i8(
 // CHECK-SAME:                   %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi8>)
@@ -363,3 +377,17 @@ func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @vector_store_i128(
+// CHECK-SAME:                     %[[TILE:.*]]: vector<[1]x[1]xi128>,
+// CHECK-SAME:                     %[[ARG0:.*]]: memref<?x?xi128>)
+// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[1]x[1]xi128> to i128
+// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i128 to i32
+// CHECK:       arm_sme.intr.st1q.horiz
+func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi128>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
new file mode 100644
index 00000000000000..de1fff5bea3f8b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
@@ -0,0 +1,113 @@
+// DEFINE: %{entry_point} = test_load_store_zaq0
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:  -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:  -e %{entry_point} -entry-point-result=void \
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+/// Note: The SME ST1Q/LD1Q instructions are currently broken in QEMU
+/// see: https://gitlab.com/qemu-project/qemu/-/issues/1833
+/// This test is expected to fail until a fixed version of QEMU can be used.
+
+/// FIXME: Remove the 'XFAIL' below once a fixed QEMU version is available
+/// (and installed on CI buildbot).
+// XFAIL: {{.*}}
+
+func.func @print_i8s(%bytes: memref<?xi8>, %len: index) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  scf.for %i = %c0 to %len step %c16 {
+    %v = vector.load %bytes[%i] : memref<?xi8>, vector<16xi8>
+    vector.print %v : vector<16xi8>
+  }
+  return
+}
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+func.func @print_str(%str: !llvm.ptr<array<17 x i8>>) {
+  %c0 = llvm.mlir.constant(0 : index) : i64
+  %str_bytes = llvm.getelementptr %str[%c0, %c0]
+    : (!llvm.ptr<array<17 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%str_bytes) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @vector_copy_i128(%src: memref<?x?xi128>, %dst: memref<?x?xi128>) {
+  %c0 = arith.constant 0 : index
+  %tile = vector.load %src[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  vector.store %tile, %dst[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+func.func @test_load_store_zaq0() {
+  %init_a_str = llvm.mlir.addressof @init_tile_a : !llvm.ptr<array<17 x i8>>
+  %init_b_str = llvm.mlir.addressof @init_tile_b : !llvm.ptr<array<17 x i8>>
+  %final_a_str = llvm.mlir.addressof @final_tile_a : !llvm.ptr<array<17 x i8>>
+  %final_b_str = llvm.mlir.addressof @final_tile_b : !llvm.ptr<array<17 x i8>>
+
+  %c0 = arith.constant 0 : index
+  %min_elts_q = arith.constant 1 : index
+  %bytes_per_128_bit = arith.constant 16 : index
+
+  /// Calculate the size of an 128-bit tile, e.g. ZA{n}.q, in bytes:
+  %vscale = vector.vscale
+  %svl_q = arith.muli %min_elts_q, %vscale : index
+  %zaq_size = arith.muli %svl_q, %svl_q : index
+  %zaq_size_bytes = arith.muli %zaq_size, %bytes_per_128_bit : index
+
+  /// Allocate memory for two 128-bit tiles (A and B) and fill them a constant.
+  /// The tiles are allocated as bytes so we can fill and print them, as there's
+  /// very little that can be done with 128-bit types directly.
+  %tile_a_bytes = memref.alloca(%zaq_size_bytes) {alignment = 16} : memref<?xi8>
+  %tile_b_bytes = memref.alloca(%zaq_size_bytes) {alignment = 16} : memref<?xi8>
+  %fill_a_i8 = arith.constant 7 : i8
+  %fill_b_i8 = arith.constant 64 : i8
+  linalg.fill ins(%fill_a_i8 : i8) outs(%tile_a_bytes : memref<?xi8>)
+  linalg.fill ins(%fill_b_i8 : i8) outs(%tile_b_bytes : memref<?xi8>)
+
+  /// Get an 128-bit view of the memory for tiles A and B:
+  %tile_a = memref.view %tile_a_bytes[%c0][%svl_q, %svl_q] :
+    memref<?xi8> to memref<?x?xi128>
+  %tile_b = memref.view %tile_b_bytes[%c0][%svl_q, %svl_q] :
+    memref<?xi8> to memref<?x?xi128>
+
+  // CHECK-LABEL: INITIAL TILE A:
+  // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
+  func.call @print_str(%init_a_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  func.call @print_i8s(%tile_a_bytes, %zaq_size_bytes) : (memref<?xi8>, index) -> ()
+  vector.print punctuation <newline>
+
+  // CHECK-LABEL: INITIAL TILE B:
+  // CHECK: ( 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 )
+  func.call @print_str(%init_b_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  func.call @print_i8s(%tile_b_bytes, %zaq_size_bytes) : (memref<?xi8>, index) -> ()
+  vector.print punctuation <newline>
+
+  /// Load tile A and store it to tile B:
+  func.call @vector_copy_i128(%tile_a, %tile_b) : (memref<?x?xi128>, memref<?x?xi128>) -> ()
+
+  // CHECK-LABEL: FINAL TILE A:
+  // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
+  func.call @print_str(%final_a_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  func.call @print_i8s(%tile_a_bytes, %zaq_size_bytes) : (memref<?xi8>, index) -> ()
+  vector.print punctuation <newline>
+
+  // CHECK-LABEL: FINAL TILE B:
+  // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
+  func.call @print_str(%final_b_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  func.call @print_i8s(%tile_b_bytes, %zaq_size_bytes) : (memref<?xi8>, index) -> ()
+
+  return
+}
+
+llvm.mlir.global internal constant @init_tile_a ("INITIAL TILE A:\0A\00")
+llvm.mlir.global internal constant @init_tile_b ("INITIAL TILE B:\0A\00")
+llvm.mlir.global internal constant @final_tile_a("  FINAL TILE A:\0A\00")
+llvm.mlir.global internal constant @final_tile_b("  FINAL TILE B:\0A\00")


        


More information about the Mlir-commits mailing list