[Mlir-commits] [mlir] a5257ae - [mlir][amx] add a full tile matrix mult example to integration tests

Aart Bik llvmlistbot at llvm.org
Wed Jan 26 14:13:37 PST 2022


Author: Aart Bik
Date: 2022-01-26T14:13:30-08:00
New Revision: a5257ae277732261dc3b9de210b8d49cca2cf021

URL: https://github.com/llvm/llvm-project/commit/a5257ae277732261dc3b9de210b8d49cca2cf021
DIFF: https://github.com/llvm/llvm-project/commit/a5257ae277732261dc3b9de210b8d49cca2cf021.diff

LOG: [mlir][amx] add a full tile matrix mult example to integration tests

Rationale:
Demonstrates the maximum tile size allowed for the f32 <- bf16 x bf16 op

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D118277

Added: 
    mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf-full.mlir

Modified: 
    

Removed: 
    


################################################################################
diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf-full.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf-full.mlir
new file mode 100644
index 0000000000000..bc1d6333bd721
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf-full.mlir
@@ -0,0 +1,141 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Note: To run this test, your CPU must support AMX.
+
+// Full-size AMX tile multiply: loads two 16x32 bf16 tiles from %lhs and %rhs, multiplies
+// them into a zero-initialized 16x16 f32 accumulator tile, and stores that tile into %out.
+func @kernel(%lhs: memref<16x32xbf16>, %rhs: memref<16x32xbf16>, %out: memref<16x16xf32>) {
+  %origin = arith.constant 0 : index
+  %lhs_tile = amx.tile_load %lhs[%origin, %origin] : memref<16x32xbf16> into vector<16x32xbf16>
+  %rhs_tile = amx.tile_load %rhs[%origin, %origin] : memref<16x32xbf16> into vector<16x32xbf16>
+  %acc = amx.tile_zero : vector<16x16xf32>
+  %prod = amx.tile_mulf %lhs_tile, %rhs_tile, %acc : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+  // Write the resulting tile back, starting at element [0, 0] of the destination.
+  amx.tile_store %out[%origin, %origin], %prod : memref<16x16xf32>, vector<16x16xf32>
+  return
+}
+
+func @entry() -> i32 {  // Test driver: builds the inputs, calls @kernel, prints the result, returns 0.
+  %fu = arith.constant -1.0: f32  // padding operand of the vector.transfer_read in the print loop
+  %c0 = arith.constant 0: index
+  %c1 = arith.constant 1: index
+  %c16 = arith.constant 16: index
+  %c32 = arith.constant 32: index  // NOTE(review): not referenced anywhere else in this function
+
+  // Set up simple test data: both operands are 1.0 everywhere except for one entry per row
+  // (1.1 .. 1.8). Note that bf16 does not seem to work well with the tensor type yet, which is
+  // why we use vectors and transfer the data into memref. TODO: use tensors and bufferization.to_memref instead
+  %0 = arith.constant dense<[  // first operand: the non-1.0 entry of row i sits at column i
+     [ 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ]
+  ]> : vector<16x32xbf16>
+  %1 = arith.constant dense<[  // second operand: non-1.0 entries run from column 9 down to 2, then from 3 up to 10
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
+     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ]
+  ]> : vector<16x32xbf16>
+
+  // Set up memory: two bf16 input buffers and one f32 output buffer; copy the constants into the inputs.
+  %a = memref.alloc() : memref<16x32xbf16>
+  %b = memref.alloc() : memref<16x32xbf16>
+  %c = memref.alloc() : memref<16x16xf32>
+  vector.transfer_write %0, %a[%c0, %c0] : vector<16x32xbf16>, memref<16x32xbf16>
+  vector.transfer_write %1, %b[%c0, %c0] : vector<16x32xbf16>, memref<16x32xbf16>
+
+  // Call kernel; it writes its result into %c.
+  call @kernel(%a, %b, %c) : (memref<16x32xbf16>, memref<16x32xbf16>, memref<16x16xf32>) -> ()
+
+  //
+  // Print and verify the 16x16 result; the loop below prints one row of %c per line.
+  // NOTE(review): the non-round expected values look like the effect of bf16 rounding of the inputs -- confirm.
+  // CHECK:      ( 32.1016, 34.3984, 34.5078, 33.6953, 32.9062, 32.2031, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016 )
+  // CHECK-NEXT: ( 32.2031, 34.5, 34.6094, 33.7969, 33.0284, 32.3047, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031 )
+  // CHECK-NEXT: ( 32.2969, 34.5938, 34.7031, 33.8906, 33.1619, 32.3984, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969 )
+  // CHECK-NEXT: ( 32.3984, 34.6953, 34.8047, 33.9922, 33.2031, 32.5, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984 )
+  // CHECK-NEXT: ( 32.5, 34.7969, 34.9062, 34.0938, 33.3047, 32.6016, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5 )
+  // CHECK-NEXT: ( 32.6016, 34.8984, 35.0078, 34.3739, 33.4062, 32.7031, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016 )
+  // CHECK-NEXT: ( 32.7031, 35, 35.1094, 34.577, 33.5078, 32.8047, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031 )
+  // CHECK-NEXT: ( 32.7969, 35.0938, 35.2031, 34.3906, 33.6016, 32.8984, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969 )
+  // CHECK-NEXT: ( 32.7969, 35.0938, 35.2031, 34.3906, 33.6016, 32.8984, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969 )
+  // CHECK-NEXT: ( 32.7031, 35, 35.4609, 34.2969, 33.5078, 32.8047, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031 )
+  // CHECK-NEXT: ( 32.6016, 34.8984, 35.3697, 34.1953, 33.4062, 32.7031, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016 )
+  // CHECK-NEXT: ( 32.5, 34.7969, 34.9062, 34.0938, 33.3047, 32.6016, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5 )
+  // CHECK-NEXT: ( 32.3984, 34.6953, 34.8047, 33.9922, 33.2031, 32.5, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984 )
+  // CHECK-NEXT: ( 32.2969, 34.8025, 34.7031, 33.8906, 33.1016, 32.3984, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969 )
+  // CHECK-NEXT: ( 32.2031, 34.6619, 34.6094, 33.7969, 33.0078, 32.3047, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031 )
+  // CHECK-NEXT: ( 32.1016, 34.3984, 34.5078, 33.6953, 32.9062, 32.2031, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016 )
+  //
+  scf.for %i = %c0 to %c16 step %c1 {
+    %v = vector.transfer_read %c[%i, %c0], %fu: memref<16x16xf32>, vector<16xf32>
+    vector.print %v : vector<16xf32>
+  }
+
+  // Release resources: free all three buffers allocated above, then return 0.
+  memref.dealloc %a : memref<16x32xbf16>
+  memref.dealloc %b : memref<16x32xbf16>
+  memref.dealloc %c : memref<16x16xf32>
+  %i0 = arith.constant 0 : i32
+  return %i0 : i32
+}


        


More information about the Mlir-commits mailing list