[Mlir-commits] [mlir] [mlir][linalg] Add e2e test for linalg.mmt4d (PR #84964)
Cullen Rhodes
llvmlistbot at llvm.org
Wed Mar 13 03:34:32 PDT 2024
================
@@ -0,0 +1,174 @@
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE: -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = main
+// DEFINE: %{run} = mlir-cpu-runner %t -e %{entry_point} -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
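+
+/// %{compile} runs and then erases the transform schedule embedded in this
+/// file, bufferizes, and lowers everything to LLVM; %{run} JIT-executes the
+/// result via mlir-cpu-runner, linking the MLIR runner utility libraries.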
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
+
+/// End-to-end test for computing a matrix multiplication using linalg.mmt4d. In
+/// particular, it demonstrates that the following MLIR sequence (implemented in @mmt4d):
+///
+/// A_pack = tensor.pack A
+/// B_pack = tensor.pack B
+/// C_pack = tensor.pack C
+/// out_pack = linalg.mmt4d(A_pack, B_pack, C_pack)
+///
+/// is equivalent to:
+///
+/// linalg.matmul(A, B, C)
+///
+/// (implemented in @matmul).
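+///
+/// linalg.mmt4d multiplies 4D packed operands, with the inner tiles of the
+/// RHS transposed (the "t" in mmt4d): given A_pack of shape M1xK1xM0xK0,
+/// B_pack of shape N1xK1xN0xK0 and an accumulator of shape M1xN1xM0xN0, it
+/// computes C[m1, n1, m0, n0] += A[m1, k1, m0, k0] * B[n1, k1, n0, k0],
+/// summing over k1 and k0.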
+
+func.func @main() {
+ // Allocate and initialise the inputs
+ %A_alloc = tensor.empty() : tensor<7x16xi32>
+ %B_alloc = tensor.empty() : tensor<16x13xi32>
+
+ %three = arith.constant 3 : i32
+ %four = arith.constant 4 : i32
+ %A = linalg.fill ins(%three : i32) outs(%A_alloc : tensor<7x16xi32>) -> tensor<7x16xi32>
+ %B = linalg.fill ins(%four : i32) outs(%B_alloc : tensor<16x13xi32>) -> tensor<16x13xi32>
+ %C = arith.constant dense<[
+ [ 1, 8, 15, 22, 29, 36, 43, 50, 57, 64, 71, 78, 85],
+ [ 2, 9, 16, 23, 30, 37, 44, 51, 58, 65, 72, 79, 86],
+ [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66, 73, 80, 87],
+ [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67, 74, 81, 88],
+ [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89],
+ [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69, 76, 83, 90],
+ [ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91]
+ ]> : tensor<7x13xi32>
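+
+ // With A filled with 3s and B with 4s, each result element accumulates
+ // 3 * 4 over the reduction dimension of 16, i.e. 3 * 4 * 16 = 192, so the
+ // expected output is C + 192 elementwise (1 + 192 = 193, ..., 91 + 192 = 283).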
+
+ // Matrix multiplication via linalg.mmt4d
+ // CHECK: Unranked Memref
+ // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277]
+ // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278]
+ // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279]
+ // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280]
+ // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281]
+ // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282]
+ // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283]
+ %C_mmt4d = func.call @mmt4d(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+ %xf = tensor.cast %C_mmt4d : tensor<7x13xi32> to tensor<*xi32>
+ call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()
+
+ // Matrix multiplication via linalg.matmul
+ // CHECK: Unranked Memref
+ // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277]
+ // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278]
+ // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279]
+ // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280]
+ // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281]
+ // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282]
+ // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283]
+ %C_matmul = func.call @matmul(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+ %xf_2 = tensor.cast %C_matmul : tensor<7x13xi32> to tensor<*xi32>
+ call @printMemrefI32(%xf_2) : (tensor<*xi32>) -> ()
+
+ return
+}
+
+func.func @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+ %C_matmul = linalg.matmul ins(%A, %B: tensor<7x16xi32>, tensor<16x13xi32>)
+ outs(%C: tensor<7x13xi32>) -> tensor<7x13xi32>
+
+ return %C_matmul : tensor<7x13xi32>
+}
+
+func.func @mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+ %zero = arith.constant 0 : i32
+
+ %cst = arith.constant 0 : i32
+ %A_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
+ %B_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
+ %C_pack_empty = tensor.empty() : tensor<2x2x8x8xi32>
+
+ // Pack matrices
+ %A_pack = tensor.pack %A padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<2x16x8x1xi32>
+ %B_pack = tensor.pack %B padding_value(%zero : i32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 1] into %B_pack_empty : tensor<16x13xi32> -> tensor<2x16x8x1xi32>
+ %C_pack = tensor.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<2x2x8x8xi32>
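+
+ // tensor.pack tiles each source into inner_tiles-sized blocks, padding
+ // incomplete tiles with padding_value; for B the outer and inner dims are
+ // permuted so the packed layout matches the N1xK1xN0xK0 shape that
+ // linalg.mmt4d expects for its RHS.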
----------------
c-rhodes wrote:
Is this a realistic packing? The second outer 8x8 block isn't used at all:
```
Unranked Memref base@ = 0xaaaaee4866c0 rank = 4 offset = 0 sizes = [2, 2, 8, 8] strides = [128, 64, 8, 1] data =
[[[[1, 8, 15, 22, 29, 36, 43, 50],
[2, 9, 16, 23, 30, 37, 44, 51],
[3, 10, 17, 24, 31, 38, 45, 52],
[4, 11, 18, 25, 32, 39, 46, 53],
[5, 12, 19, 26, 33, 40, 47, 54],
[6, 13, 20, 27, 34, 41, 48, 55],
[7, 14, 21, 28, 35, 42, 49, 56],
[0, 0, 0, 0, 0, 0, 0, 0]],
[[57, 64, 71, 78, 85, 0, 0, 0],
[58, 65, 72, 79, 86, 0, 0, 0],
[59, 66, 73, 80, 87, 0, 0, 0],
[60, 67, 74, 81, 88, 0, 0, 0],
[61, 68, 75, 82, 89, 0, 0, 0],
[62, 69, 76, 83, 90, 0, 0, 0],
[63, 70, 77, 84, 91, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0]]],
[[[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0]]]]
```
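For a 7x13 source packed with 8x8 inner tiles, ceil(7/8) = 1 and ceil(13/8) = 2, so a 1x2x8x8 destination would already hold all of the data; the extra outer block in 2x2x8x8 is pure padding.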
https://github.com/llvm/llvm-project/pull/84964