[Mlir-commits] [mlir] [mlir][ArmSME] Add support for vector.transfer_read with transpose (PR #67527)

Cullen Rhodes llvmlistbot at llvm.org
Thu Sep 28 00:43:03 PDT 2023


================
@@ -1,5 +1,212 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// vector.transfer_read (with in-flight transpose)
+//===----------------------------------------------------------------------===//
+
+// A transposed (d0, d1) -> (d1, d0), fully in-bounds read of an i8 tile is
+// expected to lower to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_i8
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
+  %padding = arith.constant 0 : i8
+  %idx = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%idx, %idx], %padding {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  "prevent.dce"(%tile) : (vector<[16]x[16]xi8>) -> ()
+  return
+}
+
+// -----
+
+// A transposed (d0, d1) -> (d1, d0), fully in-bounds read of an i16 tile is
+// expected to lower to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_i16
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) {
+  %padding = arith.constant 0 : i16
+  %idx = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%idx, %idx], %padding {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  "prevent.dce"(%tile) : (vector<[8]x[8]xi16>) -> ()
+  return
+}
+
+// -----
+
+// A transposed (d0, d1) -> (d1, d0), fully in-bounds read of an i32 tile is
+// expected to lower to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_i32
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) {
+  %padding = arith.constant 0 : i32
+  %idx = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%idx, %idx], %padding {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "prevent.dce"(%tile) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+
+// A transposed (d0, d1) -> (d1, d0), fully in-bounds read of an i64 tile is
+// expected to lower to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_i64
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_read_2d_transpose_i64(%src : memref<?x?xi64>) {
+  %padding = arith.constant 0 : i64
+  %idx = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%idx, %idx], %padding {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  "prevent.dce"(%tile) : (vector<[2]x[2]xi64>) -> ()
+  return
+}
+
+// -----
+
+// A transposed (d0, d1) -> (d1, d0), fully in-bounds read of an i128 tile is
+// expected to lower to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_i128
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @transfer_read_2d_transpose_i128(%src : memref<?x?xi128>) {
+  %padding = arith.constant 0 : i128
+  %idx = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%idx, %idx], %padding {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  "prevent.dce"(%tile) : (vector<[1]x[1]xi128>) -> ()
+  return
+}
+
+// -----
+
+// A transposed (d0, d1) -> (d1, d0), fully in-bounds read of an f16 tile is
+// expected to lower to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_f16
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @transfer_read_2d_transpose_f16(%src : memref<?x?xf16>) {
+  %padding = arith.constant 0.0 : f16
+  %idx = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%idx, %idx], %padding {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  "prevent.dce"(%tile) : (vector<[8]x[8]xf16>) -> ()
+  return
+}
+
+// -----
+
+// A transposed (d0, d1) -> (d1, d0), fully in-bounds read of a bf16 tile is
+// expected to lower to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_bf16
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_read_2d_transpose_bf16(%src : memref<?x?xbf16>) {
+  %padding = arith.constant 0.0 : bf16
+  %idx = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%idx, %idx], %padding {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  "prevent.dce"(%tile) : (vector<[8]x[8]xbf16>) -> ()
+  return
+}
+
+// -----
+
+// A transposed (d0, d1) -> (d1, d0), fully in-bounds read of an f32 tile is
+// expected to lower to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_f32
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_read_2d_transpose_f32(%src : memref<?x?xf32>) {
+  %padding = arith.constant 0.0 : f32
+  %idx = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%idx, %idx], %padding {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  "prevent.dce"(%tile) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// A transposed (d0, d1) -> (d1, d0), fully in-bounds read of an f64 tile is
+// expected to lower to a vertical tile load.
+// CHECK-LABEL: @transfer_read_2d_transpose_f64
+// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_read_2d_transpose_f64(%src : memref<?x?xf64>) {
+  %padding = arith.constant 0.0 : f64
+  %idx = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%idx, %idx], %padding {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%tile) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// vector<[4]x[4]xf64> is not a legal SME tile type (the f64 tile is
+// vector<[2]x[2]xf64>), so the read must be left untouched. The read is
+// marked in-bounds so this test does not also depend on how out-of-bounds
+// reads are handled; only the vector type is wrong.
+// CHECK-LABEL: @transfer_read_2d__bad_type
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[4]x[4]xf64>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
+  return
+}
+
+// -----
+
+// Reading from a tensor (rather than a memref) source is not lowered; the
+// vector.transfer_read is expected to survive unchanged.
+// CHECK-LABEL: @transfer_read_2d__non_memref_type
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
+  %padding = arith.constant 0.0 : f64
+  %idx = arith.constant 0 : index
+  %tile = vector.transfer_read %src[%idx, %idx], %padding {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
+  "prevent.dce"(%tile) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// A read producing a rank-1 vector from a 2-D memref is not a tile read, so
+// no tile load is expected and the vector.transfer_read is kept.
+// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
+  %padding = arith.constant 0.0 : f64
+  %idx = arith.constant 0 : index
+  %row = vector.transfer_read %src[%idx, %idx], %padding {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
+  "prevent.dce"(%row) : (vector<[2]xf64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__unsupported_mask
+// CHECK-NOT: arm_sme.tile_load
+// CHECK: vector.transfer_read
+func.func @transfer_read_2d__unsupported_mask(%src : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f64
+  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
----------------
c-rhodes wrote:

This is a valid type. I checked the debug output as well, and the failure is the expected one:
> Trying to match "(anonymous namespace)::TransferReadPermutationToArmSMELowering"
    ** Failure : masking not yet supported

https://github.com/llvm/llvm-project/pull/67527


More information about the Mlir-commits mailing list