[Mlir-commits] [mlir] d0c9fb1 - [mlir][Linalg] Improve codegen strategy
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jan 28 03:03:23 PST 2021
Author: Nicolas Vasilache
Date: 2021-01-28T10:59:16Z
New Revision: d0c9fb1b8ebf50ea8b489354b946da961a384c8b
URL: https://github.com/llvm/llvm-project/commit/d0c9fb1b8ebf50ea8b489354b946da961a384c8b
DIFF: https://github.com/llvm/llvm-project/commit/d0c9fb1b8ebf50ea8b489354b946da961a384c8b.diff
LOG: [mlir][Linalg] Improve codegen strategy
This revision improves the usability of the codegen strategy by adding a few flags that
make it easier to control from the CLI.
Usage of ModuleOp is replaced by FuncOp, as using ModuleOp created issues in multi-threaded mode.
A simple benchmarking capability is added for linalg.matmul as well as linalg.matmul_column_major;
the latter op is also added to Linalg.
Linalg integration tests that are now obsolete and take too long to run are deleted.
Correctness checks are still missing at this point.
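As a rough usage sketch, the new anchor-op flag can be exercised as below (the pass-option values are taken from the RUN lines of the added benchmark files; the input file name is only a placeholder):
  mlir-opt input.mlir \
    -test-linalg-codegen-strategy="anchor-op=linalg.matmul register-tile-sizes=12,32,16 vectorize"
The anchor-op option selects which single named Linalg op the strategy latches on (linalg.matmul, linalg.matmul_column_major, linalg.copy, or linalg.fill).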
Differential revision: https://reviews.llvm.org/D95531
Added:
mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/CMakeLists.txt
mlir/test/Dialect/Linalg/codegen-strategy.mlir
mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
Removed:
mlir/test/mlir-cpu-runner/CMakeLists.txt
mlir/test/mlir-cpu-runner/include/mlir_test_cblas.h
mlir/test/mlir-cpu-runner/include/mlir_test_cblas_interface.h
mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
mlir/test/mlir-cpu-runner/mlir_test_cblas.cpp
mlir/test/mlir-cpu-runner/mlir_test_cblas_interface.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index 765e045e9e77..e6d1e1935367 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -3,6 +3,11 @@ def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
}
+ods_def<MatmulColumnMajorOp>:
+def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
+ C(n, m) = std_addf<k>(std_mulf(A(k, m), B(n, k)));
+}
+
ods_def<MatvecOp>:
def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 436dab1ade2b..c88e1201f84b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -143,6 +143,10 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
}];
let verifier = [{ return ::verify(*this); }];
+ let assemblyFormat = [{
+ `(` operands `)` attr-dict `:` type(operands)
+ }];
+
let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
new file mode 100644
index 000000000000..3c589d163857
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
@@ -0,0 +1,99 @@
+// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
+// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul register-tile-sizes=12,32,16 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \
+
+// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
+// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// Activate to dump assembly
+// R_UN: -dump-object-file -object-filename=/tmp/a.o \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \
+// Use tee to both print to stderr and FileCheck
+// RUN: tee -a /dev/stderr | FileCheck %s
+
+
+!row_major_A = type memref<${M}x${K}xf32>
+!row_major_B = type memref<${K}x${N}xf32>
+!row_major_C = type memref<${M}x${N}xf32>
+
+func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
+// TODO: activate manually for now.
+// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]}
+{
+ linalg.matmul ins(%a, %b : !row_major_A, !row_major_B)
+ outs(%c: !row_major_C)
+ return
+}
+
+func @print_perf(%iters: index, %total_time: f64) {
+ %c2 = constant 2 : index
+ %cM = constant ${M} : index
+ %cN = constant ${N} : index
+ %cK = constant ${K} : index
+
+ %mn = muli %cM, %cN : index
+ %mnk = muli %mn, %cK : index
+
+ // 2*M*N*K.
+ %flops_per_iter = muli %c2, %mnk : index
+ %flops = muli %iters, %flops_per_iter : index
+ %flops_i64 = index_cast %flops : index to i64
+ %flops_f = sitofp %flops_i64 : i64 to f64
+ %flops_per_s = divf %flops_f, %total_time : f64
+ vector.print %flops_per_s : f64
+
+ return
+}
+
+func @main() {
+ %f0 = constant 0.0 : f32
+ %f1 = constant 1.0 : f32
+
+ %A = alloc() : !row_major_A
+ %B = alloc() : !row_major_B
+ %C = alloc() : !row_major_C
+
+ linalg.fill(%A, %f1) : !row_major_A, f32
+ linalg.fill(%B, %f1) : !row_major_B, f32
+ linalg.fill(%C, %f0) : !row_major_C, f32
+
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %iters = constant ${ITERS}: index
+
+ /// Run and dump performance for matmul.
+ /// Preheating run:
+ scf.for %arg0 = %c0 to %iters step %c1 {
+ linalg.fill(%C, %f0) : !row_major_C, f32
+ call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
+ }
+ %t_start_matmul = call @rtclock() : () -> f64
+ scf.for %arg0 = %c0 to %iters step %c1 {
+ // linalg.matmul writes %C in place, need to reset it to zero every time.
+ // This accounts for about a 10-15% perf hit on small sizes.
+ // Once linalg on tensors is ready, fusing fill at the register level will
+ // be easy.
+ linalg.fill(%C, %f0) : !row_major_C, f32
+ call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
+ }
+ %t_end_matmul = call @rtclock() : () -> f64
+ %tmatmul = subf %t_end_matmul, %t_start_matmul: f64
+ call @print_perf(%iters, %tmatmul) : (index, f64) -> ()
+
+ %res = load %C[%c0, %c0]: !row_major_C
+ // CHECK: 64
+ vector.print %res: f32
+
+ dealloc %A : !row_major_A
+ dealloc %B : !row_major_B
+ dealloc %C : !row_major_C
+
+ return
+}
+
+func private @rtclock() -> f64
+
+// TODO: init with random, run and check output.
+// func private @fill_random_f32(memref<*xf32>)
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
new file mode 100644
index 000000000000..a71643fde480
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
@@ -0,0 +1,98 @@
+// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
+// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_column_major register-tile-sizes=16,0,32 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,16 vectorize" | \
+
+// TODO: linalg.copy vectorization in the presence of permutation map fails. Enable when addressed.
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,16 vectorize" | \
+
+// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
+// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// Activate to dump assembly
+// R_UN: -dump-object-file -object-filename=/tmp/a.o \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \
+// Use tee to both print to stderr and FileCheck
+// RUN: tee -a /dev/stderr | FileCheck %s
+
+!row_major_A = type memref<${M}x${K}xf32>
+!row_major_B = type memref<${K}x${N}xf32>
+!row_major_C = type memref<${M}x${N}xf32>
+!column_major_A = type memref<${K}x${M}xf32>
+!column_major_B = type memref<${N}x${K}xf32>
+!column_major_C = type memref<${N}x${M}xf32>
+
+func @matmul_column_major(%a: !column_major_A, %b: !column_major_B, %c: !column_major_C)
+// TODO: activate manually for now.
+// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]}
+{
+ linalg.matmul_column_major ins(%a, %b : !column_major_A, !column_major_B)
+ outs(%c: !column_major_C)
+ return
+}
+
+func @print_perf(%iters: index, %total_time: f64) {
+ %c2 = constant 2 : index
+ %cM = constant ${M} : index
+ %cN = constant ${N} : index
+ %cK = constant ${K} : index
+
+ %mn = muli %cM, %cN : index
+ %mnk = muli %mn, %cK : index
+
+ // 2*M*N*K.
+ %flops_per_iter = muli %c2, %mnk : index
+ %flops = muli %iters, %flops_per_iter : index
+ %flops_i64 = index_cast %flops : index to i64
+ %flops_f = sitofp %flops_i64 : i64 to f64
+ %flops_per_s = divf %flops_f, %total_time : f64
+ vector.print %flops_per_s : f64
+
+ return
+}
+
+func @main() {
+ %f0 = constant 0.0 : f32
+ %f1 = constant 1.0 : f32
+
+ %cA = alloc() : !column_major_A
+ %cB = alloc() : !column_major_B
+ %cC = alloc() : !column_major_C
+
+ linalg.fill(%cA, %f1) : !column_major_A, f32
+ linalg.fill(%cB, %f1) : !column_major_B, f32
+ linalg.fill(%cC, %f0) : !column_major_C, f32
+
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %iters = constant ${ITERS}: index
+
+ /// Run and dump performance for matmul_column_major.
+ %t_start_matmul_column_major = call @rtclock() : () -> f64
+ scf.for %arg0 = %c0 to %iters step %c1 {
+ // linalg.matmul writes %C in place, need to reset it to zero every time.
+ // This accounts for about a 10-15% perf hit on small sizes.
+ // Once linalg on tensors is ready, fusing fill at the register level will
+ // be easy.
+ linalg.fill(%cC, %f0) : !column_major_C, f32
+ call @matmul_column_major(%cA, %cB, %cC) : (!column_major_A, !column_major_B, !column_major_C) -> ()
+ }
+ %t_end_matmul_column_major = call @rtclock() : () -> f64
+ %tmatmul_column_major = subf %t_end_matmul_column_major, %t_start_matmul_column_major: f64
+ call @print_perf(%iters, %tmatmul_column_major) : (index, f64) -> ()
+
+ %res = load %cC[%c0, %c0]: !column_major_C
+ // CHECK: 64
+ vector.print %res: f32
+
+ dealloc %cA : !column_major_A
+ dealloc %cB : !column_major_B
+ dealloc %cC : !column_major_C
+
+ return
+}
+
+func private @rtclock() -> f64
+
+// TODO: init with random, run and check output.
+// func private @fill_random_f32(memref<*xf32>)
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
new file mode 100644
index 000000000000..c8f3fe4b95d4
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
@@ -0,0 +1,116 @@
+// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
+// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_column_major register-tile-sizes=16,0,32 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul register-tile-sizes=12,32,16 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,16 vectorize" | \
+
+// TODO: linalg.copy vectorization in the presence of permutation map fails. Enable when addressed.
+// R_UN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,16 vectorize" | \
+
+// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
+// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// Activate to dump assembly
+// R_UN: -dump-object-file -object-filename=/tmp/a.o \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \
+// Use tee to both print to stderr and FileCheck
+// RUN: tee -a /dev/stderr | FileCheck %s
+
+!row_major_A = type memref<${M}x${K}xf32>
+!row_major_B = type memref<${K}x${N}xf32>
+!row_major_C = type memref<${M}x${N}xf32>
+!column_major_A = type memref<${K}x${M}xf32>
+!column_major_B = type memref<${N}x${K}xf32>
+!column_major_C = type memref<${N}x${M}xf32>
+
+func @matmul_column_major_as_row_major(
+ %ca: !column_major_A, %cb: !column_major_B, %cc: !column_major_C,
+ %a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
+// TODO: activate manually for now.
+// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]}
+{
+ linalg.copy(%ca, %a) {inputPermutation = affine_map<(i, j) -> (j, i)> } : !column_major_A, !row_major_A
+ linalg.copy(%cb, %b) {inputPermutation = affine_map<(i, j) -> (j, i)> } : !column_major_B, !row_major_B
+ linalg.matmul ins(%a, %b : !row_major_A, !row_major_B)
+ outs(%c: !row_major_C)
+ linalg.copy(%c, %cc) {inputPermutation = affine_map<(i, j) -> (j, i)> } : !row_major_C, !column_major_C
+ return
+}
+
+func @print_perf(%iters: index, %total_time: f64) {
+ %c2 = constant 2 : index
+ %cM = constant ${M} : index
+ %cN = constant ${N} : index
+ %cK = constant ${K} : index
+
+ %mn = muli %cM, %cN : index
+ %mnk = muli %mn, %cK : index
+
+ // 2*M*N*K.
+ %flops_per_iter = muli %c2, %mnk : index
+ %flops = muli %iters, %flops_per_iter : index
+ %flops_i64 = index_cast %flops : index to i64
+ %flops_f = sitofp %flops_i64 : i64 to f64
+ %flops_per_s = divf %flops_f, %total_time : f64
+ vector.print %flops_per_s : f64
+
+ return
+}
+
+func @main() {
+ %f0 = constant 0.0 : f32
+ %f1 = constant 1.0 : f32
+
+ %cA = alloc() : !column_major_A
+ %cB = alloc() : !column_major_B
+ %cC = alloc() : !column_major_C
+
+ linalg.fill(%cA, %f1) : !column_major_A, f32
+ linalg.fill(%cB, %f1) : !column_major_B, f32
+ linalg.fill(%cC, %f0) : !column_major_C, f32
+
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %iters = constant ${ITERS}: index
+
+ /// Run and dump performance for matmul_column_major as a row-major matmul.
+ %A = alloc() : !row_major_A
+ %B = alloc() : !row_major_B
+ %C = alloc() : !row_major_C
+ %t_start_matmul_column_major_as_row_major = call @rtclock() : () -> f64
+ scf.for %arg0 = %c0 to %iters step %c1 {
+ // linalg.matmul writes %C in place, need to reset it to zero every time.
+ // This accounts for about a 10-15% perf hit on small sizes.
+ // Once linalg on tensors is ready, fusing fill at the register level will
+ // be easy.
+ linalg.fill(%C, %f0) : !row_major_C, f32
+ call @matmul_column_major_as_row_major(%cA, %cB, %cC, %A, %B, %C) :
+ (!column_major_A, !column_major_B, !column_major_C,
+ !row_major_A, !row_major_B, !row_major_C) -> ()
+ }
+ %t_end_matmul_column_major_as_row_major = call @rtclock() : () -> f64
+ %tmatmul_column_major_as_row_major = subf %t_end_matmul_column_major_as_row_major, %t_start_matmul_column_major_as_row_major: f64
+ call @print_perf(%iters, %tmatmul_column_major_as_row_major) : (index, f64) -> ()
+
+ %res = load %cC[%c0, %c0]: !column_major_C
+ // CHECK: 64
+ vector.print %res: f32
+ %res2 = load %C[%c0, %c0]: !row_major_C
+ // CHECK: 64
+ vector.print %res2: f32
+
+ dealloc %A : !row_major_A
+ dealloc %B : !row_major_B
+ dealloc %C : !row_major_C
+
+ dealloc %cA : !column_major_A
+ dealloc %cB : !column_major_B
+ dealloc %cC : !column_major_C
+
+ return
+}
+
+func private @rtclock() -> f64
+
+// TODO: init with random, run and check output.
+// func private @fill_random_f32(memref<*xf32>)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index 652a036838ed..02058f886451 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -51,10 +51,11 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
// Some of these may be too aggressive as a stage 3 that is applied on each
// stage 1 application and may have to be split out to post staged patterns
// application (in which case they could just be passes, TBD).
- PassManager pm(op->getContext());
- pm.addPass(createLoopInvariantCodeMotionPass());
- if (failed(pm.run(op->getParentOfType<ModuleOp>())))
- llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
+ op->walk([&](LoopLikeOpInterface loopLike) {
+ LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n"));
+ if (failed(moveLoopInvariantCode(loopLike)))
+ llvm_unreachable("unexpected LICM failure");
+ });
promoteSingleIterationLoops(cast<FuncOp>(op));
hoistViewAllocOps(cast<FuncOp>(op));
hoistRedundantVectorTransfers(cast<FuncOp>(op));
@@ -67,13 +68,11 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
// Post staged patterns transforms
//===--------------------------------------------------------------------===//
- ModuleOp module = func->getParentOfType<ModuleOp>();
-
// Programmatic splitting of slow/fast path vector transfers.
OwningRewritePatternList patterns;
patterns.insert<vector::VectorTransferFullPartialRewriter>(
context, vectorTransformsOptions);
- applyPatternsAndFoldGreedily(module, std::move(patterns));
+ applyPatternsAndFoldGreedily(func, std::move(patterns));
// Programmatic controlled lowering of vector.contract only.
OwningRewritePatternList vectorContractLoweringPatterns;
@@ -81,17 +80,16 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
.insert<ContractionOpToOuterProductOpLowering,
ContractionOpToMatmulOpLowering, ContractionOpLowering>(
vectorTransformsOptions, context);
- applyPatternsAndFoldGreedily(module,
- std::move(vectorContractLoweringPatterns));
+ applyPatternsAndFoldGreedily(func, std::move(vectorContractLoweringPatterns));
// Programmatic controlled lowering of vector.transfer only.
OwningRewritePatternList vectorToLoopsPatterns;
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
vectorToSCFOptions);
- applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
+ applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
// Ensure we drop the marker in the end.
- module.walk([](LinalgOp op) {
+ func.walk([](LinalgOp op) {
op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
});
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2a1d4cd2ef57..76a5bb56a4b6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -68,8 +68,8 @@ static bool hasMultiplyAddBody(Region &r) {
// TODO: Should be Tablegen'd from a single source that generates the op itself.
static LogicalResult isContraction(Operation *op) {
// TODO: interface for named ops.
- if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
- linalg::VecmatOp, linalg::DotOp>(op))
+ if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatmulColumnMajorOp,
+ linalg::MatvecOp, linalg::VecmatOp, linalg::DotOp>(op))
return success();
auto genericOp = dyn_cast<linalg::GenericOp>(op);
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 293d93268a11..9ab9c0932dcd 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -1,7 +1,6 @@
add_subdirectory(Bindings)
add_subdirectory(CAPI)
add_subdirectory(EDSC)
-add_subdirectory(mlir-cpu-runner)
add_subdirectory(SDBM)
add_subdirectory(lib)
@@ -54,8 +53,6 @@ set(MLIR_TEST_DEPENDS
mlir-sdbm-api-test
mlir-tblgen
mlir-translate
- mlir_test_cblas
- mlir_test_cblas_interface
mlir_runner_utils
mlir_c_runner_utils
mlir_async_runtime
diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
index 49ef2e6f26df..60cf4fe89482 100644
--- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir
+++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
// CHECK-LABEL: func @matmul(
// OUTER-LABEL: func @matmul(
diff --git a/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
index 5e8c92c5aa04..8d80de793658 100644
--- a/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
@@ -42,6 +42,9 @@ struct TestLinalgCodegenStrategy
// clang-format on
}
+ template <typename LinalgNamedOp>
+ void applyStrategyToNamedLinalgOp();
+
void runOnFunction() override;
ListOption<int64_t> tileSizes{*this, "tile-sizes",
@@ -91,11 +94,21 @@ struct TestLinalgCodegenStrategy
*this, "unroll-vector-transfers",
llvm::cl::desc("Enable full unrolling of vector.transfer operations"),
llvm::cl::init(false)};
+ Option<std::string> anchorOpName{
+ *this, "anchor-op",
+ llvm::cl::desc(
+ "Which single linalg op is the anchor for the codegen strategy to "
+ "latch on:\n"
+ "\tlinalg.matmul: anchor on linalg.matmul\n"
+ "\tlinalg.matmul_column_major: anchor on linalg.matmul_column_major\n"
+ "\tlinalg.copy: anchor on linalg.copy\n"
+ "\tlinalg.fill: anchor on linalg.fill\n"),
+ llvm::cl::init("")};
};
} // end anonymous namespace
-/// Apply transformations specified as patterns.
-void TestLinalgCodegenStrategy::runOnFunction() {
+template <typename LinalgNamedOp>
+void TestLinalgCodegenStrategy::applyStrategyToNamedLinalgOp() {
LinalgTilingOptions tilingOptions;
if (!tileSizes.empty())
tilingOptions = tilingOptions.setTileSizes(tileSizes);
@@ -121,27 +134,42 @@ void TestLinalgCodegenStrategy::runOnFunction() {
.Default(vector::VectorTransferSplit::None);
CodegenStrategy strategy;
- strategy.tileIf<MatmulOp>(!tileSizes.empty(), tilingOptions)
- .promoteIf<MatmulOp>(promote,
- LinalgPromotionOptions()
- .setAlignment(16)
- .setUseFullTileBuffersByDefault(promoteFullTile))
- .tileIf<MatmulOp>(!registerTileSizes.empty(), registerTilingOptions)
- .promoteIf<MatmulOp>(registerPromote, LinalgPromotionOptions()
- .setAlignment(16)
- .setUseFullTileBuffersByDefault(
- registerPromoteFullTile))
- .vectorizeIf<MatmulOp>(vectorize)
+ strategy.template tileIf<LinalgNamedOp>(!tileSizes.empty(), tilingOptions)
+ .template promoteIf<LinalgNamedOp>(
+ promote, LinalgPromotionOptions()
+ .setAlignment(16)
+ .setUseFullTileBuffersByDefault(promoteFullTile))
+ .template tileIf<LinalgNamedOp>(!registerTileSizes.empty(),
+ registerTilingOptions)
+ .template promoteIf<LinalgNamedOp>(
+ registerPromote,
+ LinalgPromotionOptions()
+ .setAlignment(16)
+ .setUseFullTileBuffersByDefault(registerPromoteFullTile))
+ .template vectorizeIf<LinalgNamedOp>(vectorize)
.setVectorTransformsOptions(
vector::VectorTransformsOptions()
.setVectorTransformsOptions(vectorContractLowering)
.setVectorTransferSplit(vectorTransferSplit))
.setVectorTransferToSCFOptions(
VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
-
strategy.transform(getFunction());
}
+/// Apply transformations specified as patterns.
+void TestLinalgCodegenStrategy::runOnFunction() {
+ if (anchorOpName == MatmulOp::getOperationName())
+ applyStrategyToNamedLinalgOp<MatmulOp>();
+ else if (anchorOpName == MatmulColumnMajorOp::getOperationName())
+ applyStrategyToNamedLinalgOp<MatmulColumnMajorOp>();
+ else if (anchorOpName == CopyOp::getOperationName())
+ applyStrategyToNamedLinalgOp<CopyOp>();
+ else if (anchorOpName == FillOp::getOperationName())
+ applyStrategyToNamedLinalgOp<FillOp>();
+ else
+ llvm_unreachable("Unsupported anchor op");
+}
+
namespace mlir {
namespace test {
void registerTestLinalgCodegenStrategy() {
diff --git a/mlir/test/mlir-cpu-runner/CMakeLists.txt b/mlir/test/mlir-cpu-runner/CMakeLists.txt
deleted file mode 100644
index 62f395271855..000000000000
--- a/mlir/test/mlir-cpu-runner/CMakeLists.txt
+++ /dev/null
@@ -1,12 +0,0 @@
-set(LLVM_OPTIONAL_SOURCES
- mlir_test_cblas.cpp
- mlir_test_cblas_interface.cpp
- )
-
-add_llvm_library(mlir_test_cblas SHARED mlir_test_cblas.cpp)
-target_compile_definitions(mlir_test_cblas PRIVATE mlir_test_cblas_EXPORTS)
-
-add_llvm_library(mlir_test_cblas_interface SHARED mlir_test_cblas_interface.cpp)
-target_link_libraries(mlir_test_cblas_interface PRIVATE mlir_test_cblas)
-target_compile_definitions(mlir_test_cblas_interface PRIVATE mlir_test_cblas_interface_EXPORTS)
-
diff --git a/mlir/test/mlir-cpu-runner/include/mlir_test_cblas.h b/mlir/test/mlir-cpu-runner/include/mlir_test_cblas.h
deleted file mode 100644
index 34af2d4c6467..000000000000
--- a/mlir/test/mlir-cpu-runner/include/mlir_test_cblas.h
+++ /dev/null
@@ -1,49 +0,0 @@
-//===- mlir_test_cblas.h - Simple Blas subset -----------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-#ifndef MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_H_
-#define MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_H_
-
-#include "mlir/ExecutionEngine/RunnerUtils.h"
-
-#ifdef _WIN32
-#ifndef MLIR_TEST_CBLAS_EXPORT
-#ifdef mlir_test_cblas_EXPORTS
-// We are building this library
-#define MLIR_TEST_CBLAS_EXPORT __declspec(dllexport)
-#else
-// We are using this library
-#define MLIR_TEST_CBLAS_EXPORT __declspec(dllimport)
-#endif // mlir_test_cblas_EXPORTS
-#endif // MLIR_TEST_CBLAS_EXPORT
-#else
-#define MLIR_TEST_CBLAS_EXPORT
-#endif // _WIN32
-
-/// This reproduces a minimal subset of mlir_test_cblas to allow integration
-/// testing without explicitly requiring a dependence on an external library.
-/// Without loss of generality, various mlir_test_cblas implementations may be
-/// swapped in by including the proper headers and linking with the proper
-/// library.
-enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 };
-enum CBLAS_TRANSPOSE {
- CblasNoTrans = 111,
- CblasTrans = 112,
- CblasConjTrans = 113
-};
-
-extern "C" MLIR_TEST_CBLAS_EXPORT float
-mlir_test_cblas_sdot(const int N, const float *X, const int incX,
- const float *Y, const int incY);
-
-extern "C" MLIR_TEST_CBLAS_EXPORT void mlir_test_cblas_sgemm(
- const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
- const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
- const float alpha, const float *A, const int lda, const float *B,
- const int ldb, const float beta, float *C, const int ldc);
-
-#endif // MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_H_
diff --git a/mlir/test/mlir-cpu-runner/include/mlir_test_cblas_interface.h b/mlir/test/mlir-cpu-runner/include/mlir_test_cblas_interface.h
deleted file mode 100644
index 1f5a0e618c12..000000000000
--- a/mlir/test/mlir-cpu-runner/include/mlir_test_cblas_interface.h
+++ /dev/null
@@ -1,59 +0,0 @@
-//===- mlir_test_cblas_interface.h - Simple Blas subset interface ---------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-#ifndef MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_INTERFACE_H_
-#define MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_INTERFACE_H_
-
-#include "mlir/ExecutionEngine/RunnerUtils.h"
-
-#ifdef _WIN32
-#ifndef MLIR_TEST_CBLAS_INTERFACE_EXPORT
-#ifdef mlir_test_cblas_interface_EXPORTS
-// We are building this library
-#define MLIR_TEST_CBLAS_INTERFACE_EXPORT __declspec(dllexport)
-#else
-// We are using this library
-#define MLIR_TEST_CBLAS_INTERFACE_EXPORT __declspec(dllimport)
-#endif // mlir_test_cblas_interface_EXPORTS
-#endif // MLIR_TEST_CBLAS_INTERFACE_EXPORT
-#else
-#define MLIR_TEST_CBLAS_INTERFACE_EXPORT
-#endif // _WIN32
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X, float f);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
- float f);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
- StridedMemRefType<float, 0> *O);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
- StridedMemRefType<float, 1> *O);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
- StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
- StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
- StridedMemRefType<float, 0> *Z);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
- StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
- StridedMemRefType<float, 2> *C);
-
-#endif // MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_INTERFACE_H_
diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
deleted file mode 100644
index 846011a0f488..000000000000
--- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
+++ /dev/null
@@ -1,99 +0,0 @@
-// RUN: mlir-opt %s -convert-linalg-to-std -convert-linalg-to-llvm \
-// RUN: | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
-// RUN: | FileCheck %s
-
-// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-std -convert-linalg-to-llvm \
-// RUN: | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
-// RUN: | FileCheck %s
-
-// RUN: mlir-opt %s -convert-linalg-to-std -convert-linalg-to-llvm \
-// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
-// RUN: | FileCheck %s
-
-// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-std -convert-linalg-to-llvm \
-// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
-// RUN: | FileCheck %s
-
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -linalg-promote-subviews -convert-linalg-to-loops -convert-linalg-to-std -convert-linalg-to-llvm \
-// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
-// RUN: | FileCheck %s
-
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -linalg-promote-subviews -convert-linalg-to-std -convert-linalg-to-llvm \
-// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
-// RUN: | FileCheck %s
-
-// Creates and returns a 1-D buffer of size %s filled with the value %f
-func @alloc_filled_f32(%s : index, %f : f32) -> memref<?xi8> {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %c4 = constant 4 : index
- %s4 = muli %s, %c4: index
- %buf = alloc(%s4) {alignment = 256} : memref<?xi8>
- %V = view %buf[%c0][%s] : memref<?xi8> to memref<?xf32>
- linalg.fill(%V, %f) : memref<?xf32>, f32
- return %buf : memref<?xi8>
-}
-
-// Test for linalg.dot.
-func @dot() -> f32 {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %c16 = constant 16 : index
- %f10 = constant 10.00000e+00 : f32
- %f1 = constant 1.00000e+00 : f32
- %f2 = constant 2.00000e+00 : f32
-
- %bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (memref<?xi8>)
- %bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (memref<?xi8>)
- %bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (memref<?xi8>)
-
- %A = view %bA[%c0][%c16] : memref<?xi8> to memref<?xf32>
- %B = view %bB[%c0][%c16] : memref<?xi8> to memref<?xf32>
- %C = view %bC[%c0][] : memref<?xi8> to memref<f32>
-
- linalg.dot ins(%A, %B : memref<?xf32>, memref<?xf32>)
- outs(%C : memref<f32>)
- %res = load %C[] : memref<f32>
-
- dealloc %bC : memref<?xi8>
- dealloc %bB : memref<?xi8>
- dealloc %bA : memref<?xi8>
-
- return %res : f32
-}
-
-// Test for linalg.matmul.
-func @matmul() -> f32 {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %c6 = constant 6 : index
- %c7 = constant 7 : index
- %c2 = constant 2 : index
- %c16 = constant 16 : index
- %c4 = constant 4 : index
- %c32 = constant 32 : index
- %f1 = constant 1.00000e+00 : f32
- %f2 = constant 2.00000e+00 : f32
- %f10 = constant 10.00000e+00 : f32
-
- %bA = call @alloc_filled_f32(%c32, %f2) : (index, f32) -> (memref<?xi8>)
- %bB = call @alloc_filled_f32(%c32, %f1) : (index, f32) -> (memref<?xi8>)
- %bC = call @alloc_filled_f32(%c4, %f10) : (index, f32) -> (memref<?xi8>)
-
- %A = view %bA[%c0][%c2, %c16] : memref<?xi8> to memref<?x?xf32>
- %B = view %bB[%c0][%c16, %c2] : memref<?xi8> to memref<?x?xf32>
- %C = view %bC[%c0][%c2, %c2] : memref<?xi8> to memref<?x?xf32>
-
- linalg.matmul ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
- outs(%C : memref<?x?xf32>)
- %res = load %C[%c0, %c1] : memref<?x?xf32>
-
- dealloc %bC : memref<?xi8>
- dealloc %bB : memref<?xi8>
- dealloc %bA : memref<?xi8>
-
- return %res : f32
-}
-
-// All tests return this value
-// CHECK: 4.2{{0+}}e+01
diff --git a/mlir/test/mlir-cpu-runner/mlir_test_cblas.cpp b/mlir/test/mlir-cpu-runner/mlir_test_cblas.cpp
deleted file mode 100644
index 99a6a2658422..000000000000
--- a/mlir/test/mlir-cpu-runner/mlir_test_cblas.cpp
+++ /dev/null
@@ -1,46 +0,0 @@
-//===- mlir_test_cblas.cpp - Simple Blas subset implementation ------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Simple Blas subset implementation.
-//
-//===----------------------------------------------------------------------===//
-
-#include "include/mlir_test_cblas.h"
-#include <assert.h>
-
-extern "C" float mlir_test_cblas_sdot(const int N, const float *X,
- const int incX, const float *Y,
- const int incY) {
- float res = 0.0f;
- for (int i = 0; i < N; ++i)
- res += X[i * incX] * Y[i * incY];
- return res;
-}
-
-extern "C" void mlir_test_cblas_sgemm(
- const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
- const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
- const float alpha, const float *A, const int lda, const float *B,
- const int ldb, const float beta, float *C, const int ldc) {
- assert(Order == CBLAS_ORDER::CblasRowMajor);
- assert(TransA == CBLAS_TRANSPOSE::CblasNoTrans);
- assert(TransB == CBLAS_TRANSPOSE::CblasNoTrans);
- for (int m = 0; m < M; ++m) {
- auto *pA = A + m * lda;
- auto *pC = C + m * ldc;
- for (int n = 0; n < N; ++n) {
- float c = pC[n];
- float res = 0.0f;
- for (int k = 0; k < K; ++k) {
- auto *pB = B + k * ldb;
- res += pA[k] * pB[n];
- }
- pC[n] = alpha * c + beta * res;
- }
- }
-}
diff --git a/mlir/test/mlir-cpu-runner/mlir_test_cblas_interface.cpp b/mlir/test/mlir-cpu-runner/mlir_test_cblas_interface.cpp
deleted file mode 100644
index eaeeaf6cd550..000000000000
--- a/mlir/test/mlir-cpu-runner/mlir_test_cblas_interface.cpp
+++ /dev/null
@@ -1,107 +0,0 @@
-//===- mlir_test_cblas_interface.cpp - Simple Blas subset interface -------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Simple Blas subset interface implementation.
-//
-//===----------------------------------------------------------------------===//
-
-#include "include/mlir_test_cblas_interface.h"
-#include "include/mlir_test_cblas.h"
-#include <assert.h>
-#include <iostream>
-
-extern "C" void
-_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f) {
- X->data[X->offset] = f;
-}
-
-extern "C" void
-_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
- float f) {
- for (unsigned i = 0; i < X->sizes[0]; ++i)
- *(X->data + X->offset + i * X->strides[0]) = f;
-}
-
-extern "C" void
-_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
- float f) {
- for (unsigned i = 0; i < X->sizes[0]; ++i)
- for (unsigned j = 0; j < X->sizes[1]; ++j)
- *(X->data + X->offset + i * X->strides[0] + j * X->strides[1]) = f;
-}
-
-extern "C" void
-_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
- StridedMemRefType<float, 0> *O) {
- O->data[O->offset] = I->data[I->offset];
-}
-
-extern "C" void
-_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
- StridedMemRefType<float, 1> *O) {
- if (I->sizes[0] != O->sizes[0]) {
- std::cerr << "Incompatible strided memrefs\n";
- printMemRefMetaData(std::cerr, *I);
- printMemRefMetaData(std::cerr, *O);
- return;
- }
- for (unsigned i = 0; i < I->sizes[0]; ++i)
- O->data[O->offset + i * O->strides[0]] =
- I->data[I->offset + i * I->strides[0]];
-}
-
-extern "C" void _mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
- StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O) {
- if (I->sizes[0] != O->sizes[0] || I->sizes[1] != O->sizes[1]) {
- std::cerr << "Incompatible strided memrefs\n";
- printMemRefMetaData(std::cerr, *I);
- printMemRefMetaData(std::cerr, *O);
- return;
- }
- auto so0 = O->strides[0], so1 = O->strides[1];
- auto si0 = I->strides[0], si1 = I->strides[1];
- for (unsigned i = 0; i < I->sizes[0]; ++i)
- for (unsigned j = 0; j < I->sizes[1]; ++j)
- O->data[O->offset + i * so0 + j * so1] =
- I->data[I->offset + i * si0 + j * si1];
-}
-
-extern "C" void _mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
- StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
- StridedMemRefType<float, 0> *Z) {
- if (X->strides[0] != 1 || Y->strides[0] != 1 || X->sizes[0] != Y->sizes[0]) {
- std::cerr << "Incompatible strided memrefs\n";
- printMemRefMetaData(std::cerr, *X);
- printMemRefMetaData(std::cerr, *Y);
- printMemRefMetaData(std::cerr, *Z);
- return;
- }
- Z->data[Z->offset] +=
- mlir_test_cblas_sdot(X->sizes[0], X->data + X->offset, X->strides[0],
- Y->data + Y->offset, Y->strides[0]);
-}
-
-extern "C" void _mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
- StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
- StridedMemRefType<float, 2> *C) {
- if (A->strides[1] != B->strides[1] || A->strides[1] != C->strides[1] ||
- A->strides[1] != 1 || A->sizes[0] < A->strides[1] ||
- B->sizes[0] < B->strides[1] || C->sizes[0] < C->strides[1] ||
- C->sizes[0] != A->sizes[0] || C->sizes[1] != B->sizes[1] ||
- A->sizes[1] != B->sizes[0]) {
- printMemRefMetaData(std::cerr, *A);
- printMemRefMetaData(std::cerr, *B);
- printMemRefMetaData(std::cerr, *C);
- return;
- }
- mlir_test_cblas_sgemm(
- CBLAS_ORDER::CblasRowMajor, CBLAS_TRANSPOSE::CblasNoTrans,
- CBLAS_TRANSPOSE::CblasNoTrans, C->sizes[0], C->sizes[1], A->sizes[1],
- 1.0f, A->data + A->offset, A->strides[0], B->data + B->offset,
- B->strides[0], 1.0f, C->data + C->offset, C->strides[0]);
-}