[Mlir-commits] [mlir] [mlir][vector] Add scalable lowering for `transfer_write(transpose)` (PR #101353)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Aug 7 13:30:55 PDT 2024


================
@@ -36,3 +36,16 @@ func.func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
   return %t : tensor<?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @scalable_transpose_store
+//  CHECK-SAME: %[[TENSOR:[a-z0-9]+]]: tensor<?x?xf32>
+//       CHECK: %[[RESULT:.*]] = scf.for {{.*}} iter_args(%[[ITER_ARG:.*]] = %[[TENSOR]]) -> (tensor<?x?xf32>)
+//       CHECK:   %[[WRITE_SLICE:.*]] = vector.transfer_write %{{.*}} %[[ITER_ARG]]
+//       CHECK:   scf.yield %[[WRITE_SLICE]]
+//       CHECK: return %[[RESULT]]
+func.func @scalable_transpose_store(%vec: vector<4x[4]xf32>, %dest: tensor<?x?xf32>, %i: index, %j: index) -> tensor<?x?xf32> {
+  %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  %result = vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>,  tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
----------------
banach-space wrote:

Please maintain the style that's already present in the file:
* `%i` -> `%base1`
* `%j` -> `%base2`
* `%dest` -> `%A`

https://github.com/llvm/llvm-project/pull/101353

