[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory (PR #80170)

Cullen Rhodes llvmlistbot at llvm.org
Thu Feb 1 01:15:02 PST 2024


================
@@ -266,3 +266,78 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
   vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_no_mask(
+// CHECK-SAME:                                            %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                            %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                            %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]]  = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+  // CHECK-NEXT: return %[[LEGAL_READ]]
+  %pad = arith.constant 0.0 : f32
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[8]x4xf32>
+  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+  return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory(
+// CHECK-SAME:                                    %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index, %dim1: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]]  = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]], %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+  // CHECK-NEXT: return %[[LEGAL_READ]]
----------------
c-rhodes wrote:

Most of this is already checked in the test above and can be removed, which makes it clearer that the mask is being transposed:
```suggestion
  // CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
  // CHECK: %[[LEGAL_READ:.*]]  = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}}, %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
```
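
As a rough illustration of why the transposed mask is `vector.create_mask %[[DIM1]], %[[DIM0]]`, here is a small Python model. It is only a sketch: `create_mask` and `transpose` are hypothetical helpers that mimic the 2-D semantics on plain lists, and the sizes assume `vscale = 1` so that `[8]x4` is simply 8x4. It is not part of the PR.

```python
def create_mask(dim0, dim1, shape):
    """Model of a 2-D create_mask with concrete sizes (assumption: vscale = 1).

    Element (i, j) is enabled when i < dim0 and j < dim1.
    """
    rows, cols = shape
    return [[i < dim0 and j < dim1 for j in range(cols)] for i in range(rows)]


def transpose(mask):
    """Swap the two dimensions of a 2-D mask."""
    return [list(column) for column in zip(*mask)]


# Hypothetical mask bounds, chosen only for illustration.
dim0, dim1 = 5, 3

# Mask as it would apply to the illegal vector<[8]x4xf32> read.
original = create_mask(dim0, dim1, shape=(8, 4))

# Mask for the legal vector<4x[8]xf32> read: operands swapped.
transposed = create_mask(dim1, dim0, shape=(4, 8))

# Transposing the original mask gives the same result as swapping the operands.
assert transpose(original) == transposed
```

Swapping the operands is therefore enough to transpose a mask of this form, which is the one thing the trimmed-down CHECK lines need to show.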

https://github.com/llvm/llvm-project/pull/80170

