[Mlir-commits] [mlir] c555308 - [mlir][linalg] Add an end-to-end test for linalg.fill to ArmSME
Cullen Rhodes
llvmlistbot at llvm.org
Tue Aug 29 02:46:20 PDT 2023
Author: Cullen Rhodes
Date: 2023-08-29T09:44:52Z
New Revision: c555308d4fed150636622ccf3798dbf4a440fbff
URL: https://github.com/llvm/llvm-project/commit/c555308d4fed150636622ccf3798dbf4a440fbff
DIFF: https://github.com/llvm/llvm-project/commit/c555308d4fed150636622ccf3798dbf4a440fbff.diff
LOG: [mlir][linalg] Add an end-to-end test for linalg.fill to ArmSME
This patch adds the first integration test for ArmSME in Linalg. It uses
linalg.fill to fill a 2-D scalable vector, which represents an SME ZA tile,
with a pre-defined f32 value and prints the result to stdout.
This test is predicated on the MLIR_RUN_ARM_SME_TESTS configuration flag
being set to true.
Depends on D158586
Reviewed By: awarzynski
Differential Revision: https://reviews.llvm.org/D158619
Added:
mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/lit.local.cfg
Modified:
Removed:
################################################################################
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
new file mode 100644
index 00000000000000..dabf0dac4680e5
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt %s \
+// RUN: -test-transform-dialect-interpreter \
+// RUN: -test-transform-dialect-erase-schedule \
+// RUN: -lower-vector-mask \
+// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN: -enable-arm-streaming="mode=locally enable-za" \
+// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN: -e=entry -entry-point-result=void \
+// RUN: -march=aarch64 -mattr="+sve,+sme" \
+// RUN: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @printTestEnd() { // Prints the end-of-test marker string @str_sme_end through printCString.
+ %0 = llvm.mlir.addressof @str_sme_end : !llvm.ptr<array<24 x i8>> // 24 bytes = 23 characters plus the \0A newline in the global's initializer.
+ %1 = llvm.mlir.constant(0 : index) : i64 // Zero index, used twice below to address the first byte of the array.
+ %2 = llvm.getelementptr %0[%1, %1]
+ : (!llvm.ptr<array<24 x i8>>, i64, i64) -> !llvm.ptr<i8>
+ llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+ return
+}
+
+func.func @entry() { // Test entry point: builds an svl_s x svl_s f32 tile, prints it, refills it with linalg.fill and prints it again.
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index // NOTE(review): not referenced anywhere below.
+ %step = arith.constant 1 : index
+
+ %c123_f32 = arith.constant 123.0 : f32
+
+ %min_elts_s = arith.constant 4 : index
+ %vscale = vector.vscale
+
+ // "svl" is the Streaming Vector Length in bits; "svl_s" is the number of
+ // 32-bit elements in one vector of SVL bits, computed as 4 * vscale.
+ %svl_s = arith.muli %min_elts_s, %vscale : index
+
+ %tile_init = bufferization.alloc_tensor(%svl_s, %svl_s) : tensor<?x?xf32>
+
+ // Initialize every element of the tile to 123.0, one tensor.insert at a time.
+ // TODO: this could be simplified to tensor.splat + tensor.insert_slice once
+ // splat supports dynamically shaped tensors.
+ %tile_0 = scf.for %i = %c0 to %svl_s step %step iter_args(%tile_partial = %tile_init) -> tensor<?x?xf32> {
+ %inner_tile = scf.for %j = %c0 to %svl_s step %step iter_args(%inner_tile_partial = %tile_partial) -> tensor<?x?xf32> {
+ %tile_update = tensor.insert %c123_f32 into %inner_tile_partial[%i, %j] : tensor<?x?xf32>
+ scf.yield %tile_update : tensor<?x?xf32>
+ }
+ scf.yield %inner_tile : tensor<?x?xf32>
+ }
+
+ // Print the tile after initialization, one row per line. The smallest SVL is
+ // 128 bits, so the tile is at least 4x4xf32; only that 4x4 prefix is matched.
+ //
+ // CHECK: ( 123, 123, 123, 123
+ // CHECK-NEXT: ( 123, 123, 123, 123
+ // CHECK-NEXT: ( 123, 123, 123, 123
+ // CHECK-NEXT: ( 123, 123, 123, 123
+ scf.for %i = %c0 to %svl_s step %step {
+ vector.print punctuation <open>
+ scf.for %j = %c0 to %svl_s step %step {
+ %element = tensor.extract %tile_0[%i, %j] : tensor<?x?xf32>
+ vector.print %element : f32 punctuation <no_punctuation>
+
+ // Emit a comma after every element except the last one in the row.
+ %c1_index = arith.constant 1 : index
+ %last_i = arith.subi %svl_s, %c1_index : index // Index of the last column (compared against %j despite the "_i" suffix).
+ %isNotLastIter = arith.cmpi ult, %j, %last_i : index
+ scf.if %isNotLastIter {
+ vector.print punctuation <comma>
+ }
+ }
+ vector.print punctuation <close>
+ vector.print punctuation <newline>
+ }
+
+ // Overwrite the whole tile with 3.14 (named %pi) using linalg.fill.
+ %pi = arith.constant 3.14 : f32
+ %tile_1 = linalg.fill ins(%pi : f32) outs(%tile_0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ // Print the tile after the fill, one row per line. The smallest SVL is
+ // 128 bits, so the tile is at least 4x4xf32; only that 4x4 prefix is matched.
+ //
+ // CHECK: ( 3.14, 3.14, 3.14, 3.14
+ // CHECK-NEXT: ( 3.14, 3.14, 3.14, 3.14
+ // CHECK-NEXT: ( 3.14, 3.14, 3.14, 3.14
+ // CHECK-NEXT: ( 3.14, 3.14, 3.14, 3.14
+ scf.for %i = %c0 to %svl_s step %step {
+ vector.print punctuation <open>
+ scf.for %j = %c0 to %svl_s step %step {
+ %element = tensor.extract %tile_1[%i, %j] : tensor<?x?xf32>
+ vector.print %element : f32 punctuation <no_punctuation>
+
+ // Emit a comma after every element except the last one in the row.
+ %c1_index = arith.constant 1 : index
+ %last_i = arith.subi %svl_s, %c1_index : index // Index of the last column (compared against %j despite the "_i" suffix).
+ %isNotLastIter = arith.cmpi ult, %j, %last_i : index
+ scf.if %isNotLastIter {
+ vector.print punctuation <comma>
+ }
+ }
+ vector.print punctuation <close>
+ vector.print punctuation <newline>
+ }
+
+ // CHECK: SME: END OF TEST OUTPUT
+ func.call @printTestEnd() : () -> ()
+
+ return
+}
+
+transform.sequence failures(propagate) { // Transform schedule; presumably the one executed by -test-transform-dialect-interpreter in the pipeline at the top of this test.
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op // Select the linalg.fill ops in the payload.
+ transform.structured.masked_vectorize %0 vector_sizes [[4], [4]] : !transform.any_op // Vectorize them with masking; NOTE(review): the bracketed sizes presumably denote scalable dimensions — confirm.
+}
+
+llvm.func @printCString(!llvm.ptr<i8>) // Declaration only; NOTE(review): presumably resolved from the runner utils libraries passed via -shared-libs — confirm.
+llvm.mlir.global internal constant @str_sme_end("SME: END OF TEST OUTPUT\0A") // NOTE(review): no trailing \00 in the initializer — confirm printCString does not rely on NUL termination.
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/lit.local.cfg b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/lit.local.cfg
new file mode 100644
index 00000000000000..296b4419438e8a
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/lit.local.cfg
@@ -0,0 +1,9 @@
+import sys
+
+# ArmSME tests are opt-in: mark them unsupported unless config.mlir_run_arm_sme_tests is set (the commit message names the build flag MLIR_RUN_ARM_SME_TESTS).
+if not config.mlir_run_arm_sme_tests:
+ config.unsupported = True
+
+# These tests need JIT execution, which is not available on win32, so mark them unsupported there as well.
+if sys.platform == "win32":
+ config.unsupported = True
More information about the Mlir-commits
mailing list