[Mlir-commits] [mlir] [mlir][linalg] Add e2e test for linalg.mmt4d (PR #84964)

Cullen Rhodes llvmlistbot at llvm.org
Wed Mar 13 03:34:32 PDT 2024


================
@@ -0,0 +1,174 @@
+// DEFINE: %{compile} =  mlir-opt %s \
+// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE:    -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = main
+// DEFINE: %{run} = mlir-cpu-runner %t -e %{entry_point} -entry-point-result=void \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
+
+/// End-to-end test for computing matrix multiplication using linalg.mmt4d. In
+/// particular, demonstrates how the following MLIR sequence (implemented in @mmt4d):
+///
+///   A_pack = tensor.pack A
+///   B_pack = tensor.pack B
+///   C_pack = tensor.pack C
+///   out_pack = linalg.mmt4d(A_pack, B_pack, C_pack)
+///   out = tensor.unpack out_pack
+/// is equivalent to:
+///
+///  linalg.matmul(A, B, C)
+///
+/// (implemented in @matmul).
+
+func.func @main() {
+  // Allocate and initialise the inputs: A is all 3s, B is all 4s, C is below.
+  %A_alloc = tensor.empty() : tensor<7x16xi32>
+  %B_alloc = tensor.empty() : tensor<16x13xi32>
+
+  %three = arith.constant 3 : i32
+  %four = arith.constant 4 : i32
+  %A = linalg.fill ins(%three : i32) outs(%A_alloc : tensor<7x16xi32>) -> tensor<7x16xi32>
+  %B = linalg.fill ins(%four : i32) outs(%B_alloc : tensor<16x13xi32>) -> tensor<16x13xi32>
+  %C = arith.constant dense<[
+    [ 1,  8, 15, 22, 29, 36, 43, 50, 57, 64, 71, 78, 85],
+    [ 2,  9, 16, 23, 30, 37, 44, 51, 58, 65, 72, 79, 86],
+    [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66, 73, 80, 87],
+    [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67, 74, 81, 88],
+    [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89],
+    [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69, 76, 83, 90],
+    [ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91]
+  ]> : tensor<7x13xi32>
+
+  // Matrix multiplication via linalg.mmt4d. Every entry of A*B is 16*3*4 = 192, so expect C + 192.
+  // CHECK: Unranked Memref
+  // CHECK-NEXT:  [193,   200,   207,   214,   221,   228,   235,   242,   249,   256,   263,   270,   277]
+  // CHECK-NEXT:  [194,   201,   208,   215,   222,   229,   236,   243,   250,   257,   264,   271,   278]
+  // CHECK-NEXT:  [195,   202,   209,   216,   223,   230,   237,   244,   251,   258,   265,   272,   279]
+  // CHECK-NEXT:  [196,   203,   210,   217,   224,   231,   238,   245,   252,   259,   266,   273,   280]
+  // CHECK-NEXT:  [197,   204,   211,   218,   225,   232,   239,   246,   253,   260,   267,   274,   281]
+  // CHECK-NEXT:  [198,   205,   212,   219,   226,   233,   240,   247,   254,   261,   268,   275,   282]
+  // CHECK-NEXT:  [199,   206,   213,   220,   227,   234,   241,   248,   255,   262,   269,   276,   283]
+  %C_mmt4d = func.call @mmt4d(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+  %xf = tensor.cast %C_mmt4d : tensor<7x13xi32> to tensor<*xi32>
+  call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()
+
+  // Matrix multiplication with linalg.matmul (reference; must match the result above).
+  // CHECK: Unranked Memref
+  // CHECK-NEXT:  [193,   200,   207,   214,   221,   228,   235,   242,   249,   256,   263,   270,   277]
+  // CHECK-NEXT:  [194,   201,   208,   215,   222,   229,   236,   243,   250,   257,   264,   271,   278]
+  // CHECK-NEXT:  [195,   202,   209,   216,   223,   230,   237,   244,   251,   258,   265,   272,   279]
+  // CHECK-NEXT:  [196,   203,   210,   217,   224,   231,   238,   245,   252,   259,   266,   273,   280]
+  // CHECK-NEXT:  [197,   204,   211,   218,   225,   232,   239,   246,   253,   260,   267,   274,   281]
+  // CHECK-NEXT:  [198,   205,   212,   219,   226,   233,   240,   247,   254,   261,   268,   275,   282]
+  // CHECK-NEXT:  [199,   206,   213,   220,   227,   234,   241,   248,   255,   262,   269,   276,   283]
+  %C_matmul = func.call @matmul(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
+  %xf_2 = tensor.cast %C_matmul : tensor<7x13xi32> to tensor<*xi32>
+  call @printMemrefI32(%xf_2) : (tensor<*xi32>) -> ()
+
+  return
+}
+
+func.func @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+  // Reference path: linalg.matmul accumulates A * B into the init tensor C.
+  %res = linalg.matmul ins(%A, %B : tensor<7x16xi32>, tensor<16x13xi32>)
+                       outs(%C : tensor<7x13xi32>) -> tensor<7x13xi32>
+  return %res : tensor<7x13xi32>
+}
+
+func.func @mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
+  %zero = arith.constant 0 : i32
+
+  %cst = arith.constant 0 : i32
+  %A_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
+  %B_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
+  %C_pack_empty = tensor.empty() : tensor<2x2x8x8xi32>
+
+  // Pack matrices
+  %A_pack = tensor.pack %A padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<2x16x8x1xi32>
+  %B_pack = tensor.pack %B padding_value(%zero : i32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 1] into %B_pack_empty : tensor<16x13xi32> -> tensor<2x16x8x1xi32>
+  %C_pack = tensor.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<2x2x8x8xi32>
----------------
c-rhodes wrote:

Is this a realistic packing? The second row of 8x8 tiles isn't used at all (it is pure padding):
```
Unranked Memref base@ = 0xaaaaee4866c0 rank = 4 offset = 0 sizes = [2, 2, 8, 8] strides = [128, 64, 8, 1] data =
[[[[1,     8,     15,     22,     29,     36,     43,     50],
   [2,     9,     16,     23,     30,     37,     44,     51],
   [3,     10,     17,     24,     31,     38,     45,     52],
   [4,     11,     18,     25,     32,     39,     46,     53],
   [5,     12,     19,     26,     33,     40,     47,     54],
   [6,     13,     20,     27,     34,     41,     48,     55],
   [7,     14,     21,     28,     35,     42,     49,     56],
   [0,     0,     0,     0,     0,     0,     0,     0]],
  [[57,     64,     71,     78,     85,     0,     0,     0],
   [58,     65,     72,     79,     86,     0,     0,     0],
   [59,     66,     73,     80,     87,     0,     0,     0],
   [60,     67,     74,     81,     88,     0,     0,     0],
   [61,     68,     75,     82,     89,     0,     0,     0],
   [62,     69,     76,     83,     90,     0,     0,     0],
   [63,     70,     77,     84,     91,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0]]],
 [[[0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0]],
  [[0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0],
   [0,     0,     0,     0,     0,     0,     0,     0]]]]
```

https://github.com/llvm/llvm-project/pull/84964


More information about the Mlir-commits mailing list