[Mlir-commits] [mlir] 97da414 - [mlir][ArmSME] Lower loads/stores of (.Q) 128-bit tiles to intrinsics
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Aug 23 02:17:13 PDT 2023
Author: Benjamin Maxwell
Date: 2023-08-23T09:16:20Z
New Revision: 97da41418226bb426ae67b1e5b9c0ea3255042f2
URL: https://github.com/llvm/llvm-project/commit/97da41418226bb426ae67b1e5b9c0ea3255042f2
DIFF: https://github.com/llvm/llvm-project/commit/97da41418226bb426ae67b1e5b9c0ea3255042f2.diff
LOG: [mlir][ArmSME] Lower loads/stores of (.Q) 128-bit tiles to intrinsics
This follows from D155306.
Loads and stores of 128-bit tiles have been confirmed to work by the
`load-store-128-bit-tile.mlir` integration test. However, there is
currently a bug in QEMU (see: https://gitlab.com/qemu-project/qemu/-/issues/1833)
that causes this test to produce incorrect results (a patch for the issue
is available, but not yet in any released version of QEMU). Until a fixed
version of QEMU is available, the integration test is expected to fail.
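For illustration, this is the kind of IR the change covers, adapted from the
`vector-ops-to-llvm.mlir` test updates below. A 2-D scalable load of an i128
tile now legalizes, with each horizontal tile slice lowering to the
`arm_sme.intr.ld1q.horiz` intrinsic (stores likewise use
`arm_sme.intr.st1q.horiz`):

  func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
    %c0 = arith.constant 0 : index
    // Load of a full 128-bit-element ZA tile; lowered slice-by-slice to the
    // ld1q intrinsic by the ArmSME legalize-for-LLVM-export patterns.
    %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
    return %tile : vector<[1]x[1]xi128>
  }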
Reviewed By: c-rhodes, awarzynski
Differential Revision: https://reviews.llvm.org/D158418
Added:
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
Modified:
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 2f4dee7ba916e8..e846b63b011a62 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -223,6 +223,10 @@ struct LoadTileSliceToArmSMELowering
rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, allActiveMask, ptr,
tileI32, tileSliceI32);
break;
+ case 128:
+ rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
}
// The load intrinsics have no result, replace 'arm_sme.tile_load' with
@@ -294,6 +298,10 @@ struct StoreTileSliceToArmSMELowering
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
+ case 128:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
}
return success();
@@ -309,9 +317,10 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
- arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_st1b_horiz,
- arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
- arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_za_enable,
+ arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
+ arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
+ arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
+ arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_za_enable,
arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index a5908a5a8f330f..cc9e36cfca4193 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -25,17 +25,15 @@ unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
}
bool mlir::arm_sme::isValidSMETileElementType(Type type) {
- // TODO: add support for i128.
return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
- type.isInteger(64) || type.isF16() || type.isBF16() || type.isF32() ||
- type.isF64();
+ type.isInteger(64) || type.isInteger(128) || type.isF16() ||
+ type.isBF16() || type.isF32() || type.isF64() || type.isF128();
}
bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
if ((vType.getRank() != 2) && vType.allDimsScalable())
return false;
- // TODO: add support for i128.
auto elemType = vType.getElementType();
if (!isValidSMETileElementType(elemType))
return false;
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index de8bc5f93b2c76..af528295ef6ee2 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -220,6 +220,20 @@ func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
// -----
+// CHECK-LABEL: @vector_load_i128(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi128>)
+// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i128
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i128 to vector<[1]x[1]xi128>
+// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i128 to i32
+// CHECK: arm_sme.intr.ld1q.horiz
+func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
+ %c0 = arith.constant 0 : index
+ %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return %tile : vector<[1]x[1]xi128>
+}
+
+// -----
+
// CHECK-LABEL: @vector_store_i8(
// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
@@ -363,3 +377,17 @@ func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}
+
+// -----
+
+// CHECK-LABEL: @vector_store_i128(
+// CHECK-SAME: %[[TILE:.*]]: vector<[1]x[1]xi128>,
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi128>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[1]x[1]xi128> to i128
+// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i128 to i32
+// CHECK: arm_sme.intr.st1q.horiz
+func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi128>) {
+ %c0 = arith.constant 0 : index
+ vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
new file mode 100644
index 00000000000000..de1fff5bea3f8b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
@@ -0,0 +1,113 @@
+// DEFINE: %{entry_point} = test_load_store_zaq0
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE: -march=aarch64 -mattr=+sve,+sme \
+// DEFINE: -e %{entry_point} -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+/// Note: The SME ST1Q/LD1Q instructions are currently broken in QEMU
+/// see: https://gitlab.com/qemu-project/qemu/-/issues/1833
+/// This test is expected to fail until a fixed version of QEMU can be used.
+
+/// FIXME: Remove the 'XFAIL' below once a fixed QEMU version is available
+/// (and installed on CI buildbot).
+// XFAIL: {{.*}}
+
+func.func @print_i8s(%bytes: memref<?xi8>, %len: index) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ scf.for %i = %c0 to %len step %c16 {
+ %v = vector.load %bytes[%i] : memref<?xi8>, vector<16xi8>
+ vector.print %v : vector<16xi8>
+ }
+ return
+}
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+func.func @print_str(%str: !llvm.ptr<array<17 x i8>>) {
+ %c0 = llvm.mlir.constant(0 : index) : i64
+ %str_bytes = llvm.getelementptr %str[%c0, %c0]
+ : (!llvm.ptr<array<17 x i8>>, i64, i64) -> !llvm.ptr<i8>
+ llvm.call @printCString(%str_bytes) : (!llvm.ptr<i8>) -> ()
+ return
+}
+
+func.func @vector_copy_i128(%src: memref<?x?xi128>, %dst: memref<?x?xi128>) {
+ %c0 = arith.constant 0 : index
+ %tile = vector.load %src[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ vector.store %tile, %dst[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+func.func @test_load_store_zaq0() {
+ %init_a_str = llvm.mlir.addressof @init_tile_a : !llvm.ptr<array<17 x i8>>
+ %init_b_str = llvm.mlir.addressof @init_tile_b : !llvm.ptr<array<17 x i8>>
+ %final_a_str = llvm.mlir.addressof @final_tile_a : !llvm.ptr<array<17 x i8>>
+ %final_b_str = llvm.mlir.addressof @final_tile_b : !llvm.ptr<array<17 x i8>>
+
+ %c0 = arith.constant 0 : index
+ %min_elts_q = arith.constant 1 : index
+ %bytes_per_128_bit = arith.constant 16 : index
+
+ /// Calculate the size of a 128-bit tile, e.g. ZA{n}.q, in bytes:
+ %vscale = vector.vscale
+ %svl_q = arith.muli %min_elts_q, %vscale : index
+ %zaq_size = arith.muli %svl_q, %svl_q : index
+ %zaq_size_bytes = arith.muli %zaq_size, %bytes_per_128_bit : index
+
+ /// Allocate memory for two 128-bit tiles (A and B) and fill each with a constant.
+ /// The tiles are allocated as bytes so we can fill and print them, as there's
+ /// very little that can be done with 128-bit types directly.
+ %tile_a_bytes = memref.alloca(%zaq_size_bytes) {alignment = 16} : memref<?xi8>
+ %tile_b_bytes = memref.alloca(%zaq_size_bytes) {alignment = 16} : memref<?xi8>
+ %fill_a_i8 = arith.constant 7 : i8
+ %fill_b_i8 = arith.constant 64 : i8
+ linalg.fill ins(%fill_a_i8 : i8) outs(%tile_a_bytes : memref<?xi8>)
+ linalg.fill ins(%fill_b_i8 : i8) outs(%tile_b_bytes : memref<?xi8>)
+
+ /// Get a 128-bit view of the memory for tiles A and B:
+ %tile_a = memref.view %tile_a_bytes[%c0][%svl_q, %svl_q] :
+ memref<?xi8> to memref<?x?xi128>
+ %tile_b = memref.view %tile_b_bytes[%c0][%svl_q, %svl_q] :
+ memref<?xi8> to memref<?x?xi128>
+
+ // CHECK-LABEL: INITIAL TILE A:
+ // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
+ func.call @print_str(%init_a_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+ func.call @print_i8s(%tile_a_bytes, %zaq_size_bytes) : (memref<?xi8>, index) -> ()
+ vector.print punctuation <newline>
+
+ // CHECK-LABEL: INITIAL TILE B:
+ // CHECK: ( 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 )
+ func.call @print_str(%init_b_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+ func.call @print_i8s(%tile_b_bytes, %zaq_size_bytes) : (memref<?xi8>, index) -> ()
+ vector.print punctuation <newline>
+
+ /// Load tile A and store it to tile B:
+ func.call @vector_copy_i128(%tile_a, %tile_b) : (memref<?x?xi128>, memref<?x?xi128>) -> ()
+
+ // CHECK-LABEL: FINAL TILE A:
+ // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
+ func.call @print_str(%final_a_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+ func.call @print_i8s(%tile_a_bytes, %zaq_size_bytes) : (memref<?xi8>, index) -> ()
+ vector.print punctuation <newline>
+
+ // CHECK-LABEL: FINAL TILE B:
+ // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
+ func.call @print_str(%final_b_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+ func.call @print_i8s(%tile_b_bytes, %zaq_size_bytes) : (memref<?xi8>, index) -> ()
+
+ return
+}
+
+llvm.mlir.global internal constant @init_tile_a ("INITIAL TILE A:\0A\00")
+llvm.mlir.global internal constant @init_tile_b ("INITIAL TILE B:\0A\00")
+llvm.mlir.global internal constant @final_tile_a(" FINAL TILE A:\0A\00")
+llvm.mlir.global internal constant @final_tile_b(" FINAL TILE B:\0A\00")