[Mlir-commits] [mlir] d0c9fb1 - [mlir][Linalg] Improve codegen strategy

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jan 28 03:03:23 PST 2021


Author: Nicolas Vasilache
Date: 2021-01-28T10:59:16Z
New Revision: d0c9fb1b8ebf50ea8b489354b946da961a384c8b

URL: https://github.com/llvm/llvm-project/commit/d0c9fb1b8ebf50ea8b489354b946da961a384c8b
DIFF: https://github.com/llvm/llvm-project/commit/d0c9fb1b8ebf50ea8b489354b946da961a384c8b.diff

LOG: [mlir][Linalg] Improve codegen strategy

This revision improves the usability of the codegen strategy by adding a few flags that
make it easier to control from the CLI.
Usage of ModuleOp is replaced by FuncOp, as transforming the parent ModuleOp from within a function pass created issues in multi-threaded mode.

A simple benchmarking capability is added for linalg.matmul as well as linalg.matmul_column_major.
The latter op is also added to linalg.

Now obsolete linalg integration tests that also take too long are deleted.

Correctness checks are still missing at this point.

Differential revision: https://reviews.llvm.org/D95531

Added: 
    mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
    mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
    mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/CMakeLists.txt
    mlir/test/Dialect/Linalg/codegen-strategy.mlir
    mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp

Removed: 
    mlir/test/mlir-cpu-runner/CMakeLists.txt
    mlir/test/mlir-cpu-runner/include/mlir_test_cblas.h
    mlir/test/mlir-cpu-runner/include/mlir_test_cblas_interface.h
    mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
    mlir/test/mlir-cpu-runner/mlir_test_cblas.cpp
    mlir/test/mlir-cpu-runner/mlir_test_cblas_interface.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index 765e045e9e77..e6d1e1935367 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -3,6 +3,11 @@ def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
   C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
 }
 
+ods_def<MatmulColumnMajorOp>:
+def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
+  C(n, m) = std_addf<k>(std_mulf(A(k, m), B(n, k)));
+}
+
 ods_def<MatvecOp>:
 def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
   x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 436dab1ade2b..c88e1201f84b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -143,6 +143,10 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
   }];
   let verifier = [{ return ::verify(*this); }];
 
+  let assemblyFormat = [{
+    `(` operands `)` attr-dict `:` type(operands)
+  }];
+
   let hasFolder = 1;
   let hasCanonicalizer = 1;
 }

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
new file mode 100644
index 000000000000..3c589d163857
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul.mlir
@@ -0,0 +1,99 @@
+// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
+// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul register-tile-sizes=12,32,16 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,32 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,32 vectorize" | \
+
+// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
+// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// Activate to dump assembly
+// R_UN:   -dump-object-file -object-filename=/tmp/a.o \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \
+// Use tee to both print to stderr and FileCheck
+// RUN: tee -a /dev/stderr | FileCheck %s
+
+
+!row_major_A = type memref<${M}x${K}xf32>
+!row_major_B = type memref<${K}x${N}xf32>
+!row_major_C = type memref<${M}x${N}xf32>
+
+func @matmul(%a: !row_major_A, %b: !row_major_B, %c: !row_major_C)
+// TODO: activate manually for now.
+// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]}
+{
+  linalg.matmul ins(%a, %b : !row_major_A, !row_major_B)
+    outs(%c: !row_major_C)
+  return
+}
+
+func @print_perf(%iters: index, %total_time: f64) {
+  %c2 = constant 2 : index
+  %cM = constant ${M} : index
+  %cN = constant ${N} : index
+  %cK = constant ${K} : index
+
+  %mn = muli %cM, %cN : index
+  %mnk = muli %mn, %cK : index
+
+  // 2*M*N*K.
+  %flops_per_iter = muli %c2, %mnk : index
+  %flops = muli %iters, %flops_per_iter : index
+  %flops_i64 = index_cast %flops : index to i64
+  %flops_f = sitofp %flops_i64 : i64 to f64
+  %flops_per_s = divf %flops_f, %total_time : f64
+  vector.print %flops_per_s : f64
+
+  return
+}
+
+func @main() {
+  %f0 = constant 0.0 : f32
+  %f1 = constant 1.0 : f32
+
+  %A = alloc() : !row_major_A
+  %B = alloc() : !row_major_B
+  %C = alloc() : !row_major_C
+
+  linalg.fill(%A, %f1) : !row_major_A, f32
+  linalg.fill(%B, %f1) : !row_major_B, f32
+  linalg.fill(%C, %f0) : !row_major_C, f32
+
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %iters = constant ${ITERS}: index
+
+  /// Run and dump performance for matmul.
+  /// Preheating run:
+  scf.for %arg0 = %c0 to %iters step %c1 {
+    linalg.fill(%C, %f0) : !row_major_C, f32
+    call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
+  }
+  %t_start_matmul = call @rtclock() : () -> f64
+  scf.for %arg0 = %c0 to %iters step %c1 {
+    // linalg.matmul writes %C in place, need to reset it to zero every time.
+    // This accounts for about a 10-15% perf hit on small sizes.
+    // Once linalg on tensors is ready, fusing fill at the register level will
+    // be easy.
+    linalg.fill(%C, %f0) : !row_major_C, f32
+    call @matmul(%A, %B, %C) : (!row_major_A, !row_major_B, !row_major_C) -> ()
+  }
+  %t_end_matmul = call @rtclock() : () -> f64
+  %tmatmul = subf %t_end_matmul, %t_start_matmul: f64
+  call @print_perf(%iters, %tmatmul) : (index, f64) -> ()
+
+  %res = load %C[%c0, %c0]: !row_major_C
+  // CHECK: 64
+  vector.print %res: f32
+
+  dealloc %A : !row_major_A
+  dealloc %B : !row_major_B
+  dealloc %C : !row_major_C
+
+  return
+}
+
+func private @rtclock() -> f64
+
+// TODO: init with random, run and check output.
+// func private @fill_random_f32(memref<*xf32>)

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
new file mode 100644
index 000000000000..a71643fde480
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major.mlir
@@ -0,0 +1,98 @@
+// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
+// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_column_major register-tile-sizes=16,0,32 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,16 vectorize" | \
+
+// TODO: linalg.copy vectorization in the presence of permutation map fails. Enable when addressed.
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,16 vectorize" | \
+
+// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
+// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// Activate to dump assembly
+// R_UN:   -dump-object-file -object-filename=/tmp/a.o \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \
+// Use tee to both print to stderr and FileCheck
+// RUN: tee -a /dev/stderr | FileCheck %s
+
+!row_major_A = type memref<${M}x${K}xf32>
+!row_major_B = type memref<${K}x${N}xf32>
+!row_major_C = type memref<${M}x${N}xf32>
+!column_major_A = type memref<${K}x${M}xf32>
+!column_major_B = type memref<${N}x${K}xf32>
+!column_major_C = type memref<${N}x${M}xf32>
+
+func @matmul_column_major(%a: !column_major_A, %b: !column_major_B, %c: !column_major_C)
+// TODO: activate manually for now.
+// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]}
+{
+  linalg.matmul_column_major ins(%a, %b : !column_major_A, !column_major_B)
+    outs(%c: !column_major_C)
+  return
+}
+
+func @print_perf(%iters: index, %total_time: f64) {
+  %c2 = constant 2 : index
+  %cM = constant ${M} : index
+  %cN = constant ${N} : index
+  %cK = constant ${K} : index
+
+  %mn = muli %cM, %cN : index
+  %mnk = muli %mn, %cK : index
+
+  // 2*M*N*K.
+  %flops_per_iter = muli %c2, %mnk : index
+  %flops = muli %iters, %flops_per_iter : index
+  %flops_i64 = index_cast %flops : index to i64
+  %flops_f = sitofp %flops_i64 : i64 to f64
+  %flops_per_s = divf %flops_f, %total_time : f64
+  vector.print %flops_per_s : f64
+
+  return
+}
+
+func @main() {
+  %f0 = constant 0.0 : f32
+  %f1 = constant 1.0 : f32
+
+  %cA = alloc() : !column_major_A
+  %cB = alloc() : !column_major_B
+  %cC = alloc() : !column_major_C
+
+  linalg.fill(%cA, %f1) : !column_major_A, f32
+  linalg.fill(%cB, %f1) : !column_major_B, f32
+  linalg.fill(%cC, %f0) : !column_major_C, f32
+
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %iters = constant ${ITERS}: index
+
+  /// Run and dump performance for matmul_column_major.
+  %t_start_matmul_column_major = call @rtclock() : () -> f64
+  scf.for %arg0 = %c0 to %iters step %c1 {
+    // linalg.matmul_column_major writes %cC in place, need to reset it to zero
+    // every time. This accounts for about a 10-15% perf hit on small sizes.
+    // Once linalg on tensors is ready, fusing fill at the register level will
+    // be easy.
+    linalg.fill(%cC, %f0) : !column_major_C, f32
+    call @matmul_column_major(%cA, %cB, %cC) : (!column_major_A, !column_major_B, !column_major_C) -> ()
+  }
+  %t_end_matmul_column_major = call @rtclock() : () -> f64
+  %tmatmul_column_major = subf %t_end_matmul_column_major, %t_start_matmul_column_major: f64
+  call @print_perf(%iters, %tmatmul_column_major) : (index, f64) -> ()
+
+  %res = load %cC[%c0, %c0]: !column_major_C
+  // CHECK: 64
+  vector.print %res: f32
+
+  dealloc %cA : !column_major_A
+  dealloc %cB : !column_major_B
+  dealloc %cC : !column_major_C
+
+  return
+}
+
+func private @rtclock() -> f64
+
+// TODO: init with random, run and check output.
+// func private @fill_random_f32(memref<*xf32>)

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
new file mode 100644
index 000000000000..c8f3fe4b95d4
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/benchmark_matmul_column_major_as_row_major.mlir
@@ -0,0 +1,116 @@
+// RUN: export M=24 && export K=64 && export N=192 && export ITERS=10 && \
+// RUN: cat %s | sed 's@${M}@'"$M"'@g'| sed 's@${K}@'"$K"'@g' | sed 's@${N}@'"$N"'@g'| sed 's@${ITERS}@'"$ITERS"'@g'| \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul_column_major register-tile-sizes=16,0,32 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.matmul register-tile-sizes=12,32,16 vectorize" | \
+// RUN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.fill register-tile-sizes=4,16 vectorize" | \
+
+// TODO: linalg.copy vectorization in the presence of permutation map fails. Enable when addressed.
+// R_UN: mlir-opt -test-linalg-codegen-strategy="anchor-op=linalg.copy register-tile-sizes=4,16 vectorize" | \
+
+// RUN: mlir-opt -canonicalize -convert-vector-to-scf -lower-affine -convert-linalg-to-loops | \
+// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// Activate to dump assembly
+// R_UN:   -dump-object-file -object-filename=/tmp/a.o \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | \
+// Use tee to both print to stderr and FileCheck
+// RUN: tee -a /dev/stderr | FileCheck %s
+
+!row_major_A = type memref<${M}x${K}xf32>
+!row_major_B = type memref<${K}x${N}xf32>
+!row_major_C = type memref<${M}x${N}xf32>
+!column_major_A = type memref<${K}x${M}xf32>
+!column_major_B = type memref<${N}x${K}xf32>
+!column_major_C = type memref<${N}x${M}xf32>
+
+func @matmul_column_major_as_row_major(
+  %ca: !column_major_A, %cb: !column_major_B, %cc: !column_major_C,
+   %a: !row_major_A,     %b: !row_major_B,     %c: !row_major_C)
+// TODO: activate manually for now.
+// attributes { passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]}
+{
+  linalg.copy(%ca, %a) {inputPermutation = affine_map<(i, j) -> (j, i)> } : !column_major_A, !row_major_A
+  linalg.copy(%cb, %b) {inputPermutation = affine_map<(i, j) -> (j, i)> } : !column_major_B, !row_major_B
+  linalg.matmul ins(%a, %b : !row_major_A, !row_major_B)
+    outs(%c: !row_major_C)
+  linalg.copy(%c, %cc) {inputPermutation = affine_map<(i, j) -> (j, i)> } : !row_major_C, !column_major_C
+  return
+}
+
+func @print_perf(%iters: index, %total_time: f64) {
+  %c2 = constant 2 : index
+  %cM = constant ${M} : index
+  %cN = constant ${N} : index
+  %cK = constant ${K} : index
+
+  %mn = muli %cM, %cN : index
+  %mnk = muli %mn, %cK : index
+
+  // 2*M*N*K.
+  %flops_per_iter = muli %c2, %mnk : index
+  %flops = muli %iters, %flops_per_iter : index
+  %flops_i64 = index_cast %flops : index to i64
+  %flops_f = sitofp %flops_i64 : i64 to f64
+  %flops_per_s = divf %flops_f, %total_time : f64
+  vector.print %flops_per_s : f64
+
+  return
+}
+
+func @main() {
+  %f0 = constant 0.0 : f32
+  %f1 = constant 1.0 : f32
+
+  %cA = alloc() : !column_major_A
+  %cB = alloc() : !column_major_B
+  %cC = alloc() : !column_major_C
+
+  linalg.fill(%cA, %f1) : !column_major_A, f32
+  linalg.fill(%cB, %f1) : !column_major_B, f32
+  linalg.fill(%cC, %f0) : !column_major_C, f32
+
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %iters = constant ${ITERS}: index
+
+  /// Run and dump performance for matmul_column_major as a row-major
+  %A = alloc() : !row_major_A
+  %B = alloc() : !row_major_B
+  %C = alloc() : !row_major_C
+  %t_start_matmul_column_major_as_row_major = call @rtclock() : () -> f64
+  scf.for %arg0 = %c0 to %iters step %c1 {
+    // linalg.matmul writes %C in place, need to reset it to zero every time.
+    // This accounts for about a 10-15% perf hit on small sizes.
+    // Once linalg on tensors is ready, fusing fill at the register level will
+    // be easy.
+    linalg.fill(%C, %f0) : !row_major_C, f32
+    call @matmul_column_major_as_row_major(%cA, %cB, %cC, %A, %B, %C) :
+      (!column_major_A, !column_major_B, !column_major_C,
+       !row_major_A, !row_major_B, !row_major_C) -> ()
+  }
+  %t_end_matmul_column_major_as_row_major = call @rtclock() : () -> f64
+  %tmatmul_column_major_as_row_major = subf %t_end_matmul_column_major_as_row_major, %t_start_matmul_column_major_as_row_major: f64
+  call @print_perf(%iters, %tmatmul_column_major_as_row_major) : (index, f64) -> ()
+
+  %res = load %cC[%c0, %c0]: !column_major_C
+  // CHECK: 64
+  vector.print %res: f32
+  %res2 = load %C[%c0, %c0]: !row_major_C
+  // CHECK: 64
+  vector.print %res2: f32
+
+  dealloc %A : !row_major_A
+  dealloc %B : !row_major_B
+  dealloc %C : !row_major_C
+
+  dealloc %cA : !column_major_A
+  dealloc %cB : !column_major_B
+  dealloc %cC : !column_major_C
+
+  return
+}
+
+func private @rtclock() -> f64
+
+// TODO: init with random, run and check output.
+// func private @fill_random_f32(memref<*xf32>)

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index 652a036838ed..02058f886451 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -51,10 +51,11 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
     // Some of these may be too aggressive as a stage 3 that is applied on each
     // stage 1 application and may have to be split out to post staged patterns
     // application (in which case they could just be passes, TBD).
-    PassManager pm(op->getContext());
-    pm.addPass(createLoopInvariantCodeMotionPass());
-    if (failed(pm.run(op->getParentOfType<ModuleOp>())))
-      llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
+    op->walk([&](LoopLikeOpInterface loopLike) {
+      LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n"));
+      if (failed(moveLoopInvariantCode(loopLike)))
+        llvm_unreachable("unexpected LICM failure");
+    });
     promoteSingleIterationLoops(cast<FuncOp>(op));
     hoistViewAllocOps(cast<FuncOp>(op));
     hoistRedundantVectorTransfers(cast<FuncOp>(op));
@@ -67,13 +68,11 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
   // Post staged patterns transforms
   //===--------------------------------------------------------------------===//
 
-  ModuleOp module = func->getParentOfType<ModuleOp>();
-
   // Programmatic splitting of slow/fast path vector transfers.
   OwningRewritePatternList patterns;
   patterns.insert<vector::VectorTransferFullPartialRewriter>(
       context, vectorTransformsOptions);
-  applyPatternsAndFoldGreedily(module, std::move(patterns));
+  applyPatternsAndFoldGreedily(func, std::move(patterns));
 
   // Programmatic controlled lowering of vector.contract only.
   OwningRewritePatternList vectorContractLoweringPatterns;
@@ -81,17 +80,16 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
       .insert<ContractionOpToOuterProductOpLowering,
               ContractionOpToMatmulOpLowering, ContractionOpLowering>(
           vectorTransformsOptions, context);
-  applyPatternsAndFoldGreedily(module,
-                               std::move(vectorContractLoweringPatterns));
+  applyPatternsAndFoldGreedily(func, std::move(vectorContractLoweringPatterns));
 
   // Programmatic controlled lowering of vector.transfer only.
   OwningRewritePatternList vectorToLoopsPatterns;
   populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
                                         vectorToSCFOptions);
-  applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
+  applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
 
   // Ensure we drop the marker in the end.
-  module.walk([](LinalgOp op) {
+  func.walk([](LinalgOp op) {
     op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
   });
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2a1d4cd2ef57..76a5bb56a4b6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -68,8 +68,8 @@ static bool hasMultiplyAddBody(Region &r) {
 // TODO: Should be Tablegen'd from a single source that generates the op itself.
 static LogicalResult isContraction(Operation *op) {
   // TODO: interface for named ops.
-  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatvecOp,
-          linalg::VecmatOp, linalg::DotOp>(op))
+  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatmulColumnMajorOp,
+          linalg::MatvecOp, linalg::VecmatOp, linalg::DotOp>(op))
     return success();
 
   auto genericOp = dyn_cast<linalg::GenericOp>(op);

diff  --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 293d93268a11..9ab9c0932dcd 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -1,7 +1,6 @@
 add_subdirectory(Bindings)
 add_subdirectory(CAPI)
 add_subdirectory(EDSC)
-add_subdirectory(mlir-cpu-runner)
 add_subdirectory(SDBM)
 add_subdirectory(lib)
 
@@ -54,8 +53,6 @@ set(MLIR_TEST_DEPENDS
   mlir-sdbm-api-test
   mlir-tblgen
   mlir-translate
-  mlir_test_cblas
-  mlir_test_cblas_interface
   mlir_runner_utils
   mlir_c_runner_utils
   mlir_async_runtime

diff  --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
index 49ef2e6f26df..60cf4fe89482 100644
--- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir
+++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
 
 // CHECK-LABEL: func @matmul(
 // OUTER-LABEL: func @matmul(

diff  --git a/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
index 5e8c92c5aa04..8d80de793658 100644
--- a/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
@@ -42,6 +42,9 @@ struct TestLinalgCodegenStrategy
     // clang-format on
   }
 
+  template <typename LinalgNamedOp>
+  void applyStrategyToNamedLinalgOp();
+
   void runOnFunction() override;
 
   ListOption<int64_t> tileSizes{*this, "tile-sizes",
@@ -91,11 +94,21 @@ struct TestLinalgCodegenStrategy
       *this, "unroll-vector-transfers",
       llvm::cl::desc("Enable full unrolling of vector.transfer operations"),
       llvm::cl::init(false)};
+  Option<std::string> anchorOpName{
+      *this, "anchor-op",
+      llvm::cl::desc(
+          "Which single linalg op is the anchor for the codegen strategy to "
+          "latch on:\n"
+          "\tlinalg.matmul: anchor on linalg.matmul\n"
+          "\tlinalg.matmul_column_major: anchor on linalg.matmul_column_major\n"
+          "\tlinalg.copy: anchor on linalg.copy\n"
+          "\tlinalg.fill: anchor on linalg.fill\n"),
+      llvm::cl::init("")};
 };
 } // end anonymous namespace
 
-/// Apply transformations specified as patterns.
-void TestLinalgCodegenStrategy::runOnFunction() {
+template <typename LinalgNamedOp>
+void TestLinalgCodegenStrategy::applyStrategyToNamedLinalgOp() {
   LinalgTilingOptions tilingOptions;
   if (!tileSizes.empty())
     tilingOptions = tilingOptions.setTileSizes(tileSizes);
@@ -121,27 +134,42 @@ void TestLinalgCodegenStrategy::runOnFunction() {
           .Default(vector::VectorTransferSplit::None);
 
   CodegenStrategy strategy;
-  strategy.tileIf<MatmulOp>(!tileSizes.empty(), tilingOptions)
-      .promoteIf<MatmulOp>(promote,
-                           LinalgPromotionOptions()
-                               .setAlignment(16)
-                               .setUseFullTileBuffersByDefault(promoteFullTile))
-      .tileIf<MatmulOp>(!registerTileSizes.empty(), registerTilingOptions)
-      .promoteIf<MatmulOp>(registerPromote, LinalgPromotionOptions()
-                                                .setAlignment(16)
-                                                .setUseFullTileBuffersByDefault(
-                                                    registerPromoteFullTile))
-      .vectorizeIf<MatmulOp>(vectorize)
+  strategy.template tileIf<LinalgNamedOp>(!tileSizes.empty(), tilingOptions)
+      .template promoteIf<LinalgNamedOp>(
+          promote, LinalgPromotionOptions()
+                       .setAlignment(16)
+                       .setUseFullTileBuffersByDefault(promoteFullTile))
+      .template tileIf<LinalgNamedOp>(!registerTileSizes.empty(),
+                                      registerTilingOptions)
+      .template promoteIf<LinalgNamedOp>(
+          registerPromote,
+          LinalgPromotionOptions()
+              .setAlignment(16)
+              .setUseFullTileBuffersByDefault(registerPromoteFullTile))
+      .template vectorizeIf<LinalgNamedOp>(vectorize)
       .setVectorTransformsOptions(
           vector::VectorTransformsOptions()
               .setVectorTransformsOptions(vectorContractLowering)
               .setVectorTransferSplit(vectorTransferSplit))
       .setVectorTransferToSCFOptions(
           VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
-
   strategy.transform(getFunction());
 }
 
+/// Apply the codegen strategy to `anchor-op`; unknown names fail the pass.
+void TestLinalgCodegenStrategy::runOnFunction() {
+  if (anchorOpName == MatmulOp::getOperationName())
+    return applyStrategyToNamedLinalgOp<MatmulOp>();
+  if (anchorOpName == MatmulColumnMajorOp::getOperationName())
+    return applyStrategyToNamedLinalgOp<MatmulColumnMajorOp>();
+  if (anchorOpName == CopyOp::getOperationName())
+    return applyStrategyToNamedLinalgOp<CopyOp>();
+  if (anchorOpName == FillOp::getOperationName())
+    return applyStrategyToNamedLinalgOp<FillOp>();
+  getFunction().emitError("unsupported anchor-op: ") << anchorOpName.getValue();
+  signalPassFailure();
+}
+
 namespace mlir {
 namespace test {
 void registerTestLinalgCodegenStrategy() {

diff  --git a/mlir/test/mlir-cpu-runner/CMakeLists.txt b/mlir/test/mlir-cpu-runner/CMakeLists.txt
deleted file mode 100644
index 62f395271855..000000000000
--- a/mlir/test/mlir-cpu-runner/CMakeLists.txt
+++ /dev/null
@@ -1,12 +0,0 @@
-set(LLVM_OPTIONAL_SOURCES
-  mlir_test_cblas.cpp
-  mlir_test_cblas_interface.cpp
-  )
-
-add_llvm_library(mlir_test_cblas SHARED mlir_test_cblas.cpp)
-target_compile_definitions(mlir_test_cblas PRIVATE mlir_test_cblas_EXPORTS)
-
-add_llvm_library(mlir_test_cblas_interface SHARED mlir_test_cblas_interface.cpp)
-target_link_libraries(mlir_test_cblas_interface PRIVATE mlir_test_cblas)
-target_compile_definitions(mlir_test_cblas_interface PRIVATE mlir_test_cblas_interface_EXPORTS)
-

diff  --git a/mlir/test/mlir-cpu-runner/include/mlir_test_cblas.h b/mlir/test/mlir-cpu-runner/include/mlir_test_cblas.h
deleted file mode 100644
index 34af2d4c6467..000000000000
--- a/mlir/test/mlir-cpu-runner/include/mlir_test_cblas.h
+++ /dev/null
@@ -1,49 +0,0 @@
-//===- mlir_test_cblas.h - Simple Blas subset -----------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-#ifndef MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_H_
-#define MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_H_
-
-#include "mlir/ExecutionEngine/RunnerUtils.h"
-
-#ifdef _WIN32
-#ifndef MLIR_TEST_CBLAS_EXPORT
-#ifdef mlir_test_cblas_EXPORTS
-// We are building this library
-#define MLIR_TEST_CBLAS_EXPORT __declspec(dllexport)
-#else
-// We are using this library
-#define MLIR_TEST_CBLAS_EXPORT __declspec(dllimport)
-#endif // mlir_test_cblas_EXPORTS
-#endif // MLIR_TEST_CBLAS_EXPORT
-#else
-#define MLIR_TEST_CBLAS_EXPORT
-#endif // _WIN32
-
-/// This reproduces a minimal subset of mlir_test_cblas to allow integration
-/// testing without explicitly requiring a dependence on an external library.
-/// Without loss of generality, various mlir_test_cblas implementations may be
-/// swapped in by including the proper headers and linking with the proper
-/// library.
-enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 };
-enum CBLAS_TRANSPOSE {
-  CblasNoTrans = 111,
-  CblasTrans = 112,
-  CblasConjTrans = 113
-};
-
-extern "C" MLIR_TEST_CBLAS_EXPORT float
-mlir_test_cblas_sdot(const int N, const float *X, const int incX,
-                     const float *Y, const int incY);
-
-extern "C" MLIR_TEST_CBLAS_EXPORT void mlir_test_cblas_sgemm(
-    const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
-    const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
-    const float alpha, const float *A, const int lda, const float *B,
-    const int ldb, const float beta, float *C, const int ldc);
-
-#endif // MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_H_

diff  --git a/mlir/test/mlir-cpu-runner/include/mlir_test_cblas_interface.h b/mlir/test/mlir-cpu-runner/include/mlir_test_cblas_interface.h
deleted file mode 100644
index 1f5a0e618c12..000000000000
--- a/mlir/test/mlir-cpu-runner/include/mlir_test_cblas_interface.h
+++ /dev/null
@@ -1,59 +0,0 @@
-//===- mlir_test_cblas_interface.h - Simple Blas subset interface ---------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-#ifndef MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_INTERFACE_H_
-#define MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_INTERFACE_H_
-
-#include "mlir/ExecutionEngine/RunnerUtils.h"
-
-#ifdef _WIN32
-#ifndef MLIR_TEST_CBLAS_INTERFACE_EXPORT
-#ifdef mlir_test_cblas_interface_EXPORTS
-// We are building this library
-#define MLIR_TEST_CBLAS_INTERFACE_EXPORT __declspec(dllexport)
-#else
-// We are using this library
-#define MLIR_TEST_CBLAS_INTERFACE_EXPORT __declspec(dllimport)
-#endif // mlir_test_cblas_interface_EXPORTS
-#endif // MLIR_TEST_CBLAS_INTERFACE_EXPORT
-#else
-#define MLIR_TEST_CBLAS_INTERFACE_EXPORT
-#endif // _WIN32
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X, float f);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
-                                         float f);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
-                                         StridedMemRefType<float, 0> *O);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
-                                             StridedMemRefType<float, 1> *O);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
-    StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
-    StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
-    StridedMemRefType<float, 0> *Z);
-
-extern "C" MLIR_TEST_CBLAS_INTERFACE_EXPORT void
-_mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
-    StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
-    StridedMemRefType<float, 2> *C);
-
-#endif // MLIR_CPU_RUNNER_MLIR_TEST_CBLAS_INTERFACE_H_

diff  --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
deleted file mode 100644
index 846011a0f488..000000000000
--- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
+++ /dev/null
@@ -1,99 +0,0 @@
-// RUN: mlir-opt %s -convert-linalg-to-std -convert-linalg-to-llvm \
-// RUN: | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
-// RUN: | FileCheck %s
-
-// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-std -convert-linalg-to-llvm \
-// RUN: | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
-// RUN: | FileCheck %s
-
-// RUN: mlir-opt %s -convert-linalg-to-std -convert-linalg-to-llvm \
-// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
-// RUN: | FileCheck %s
-
-// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-std -convert-linalg-to-llvm \
-// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
-// RUN: | FileCheck %s
-
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -linalg-promote-subviews -convert-linalg-to-loops -convert-linalg-to-std -convert-linalg-to-llvm \
-// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
-// RUN: | FileCheck %s
-
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -linalg-promote-subviews -convert-linalg-to-std -convert-linalg-to-llvm \
-// RUN: | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext \
-// RUN: | FileCheck %s
-
-// Creates and returns a 1-D buffer of size %s filled with the value %f
-func @alloc_filled_f32(%s : index, %f : f32) -> memref<?xi8> {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %c4 = constant 4 : index
-  %s4 = muli %s, %c4: index
-  %buf = alloc(%s4) {alignment = 256} : memref<?xi8>
-  %V = view %buf[%c0][%s] : memref<?xi8> to memref<?xf32>
-  linalg.fill(%V, %f) : memref<?xf32>, f32
-  return %buf : memref<?xi8>
-}
-
-// Test for linalg.dot.
-func @dot() -> f32 {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %c16 = constant 16 : index
-  %f10 = constant 10.00000e+00 : f32
-  %f1 = constant 1.00000e+00 : f32
-  %f2 = constant 2.00000e+00 : f32
-
-  %bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (memref<?xi8>)
-  %bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (memref<?xi8>)
-  %bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (memref<?xi8>)
-
-  %A = view %bA[%c0][%c16] : memref<?xi8> to memref<?xf32>
-  %B = view %bB[%c0][%c16] : memref<?xi8> to memref<?xf32>
-  %C = view %bC[%c0][] : memref<?xi8> to memref<f32>
-
-  linalg.dot ins(%A, %B : memref<?xf32>, memref<?xf32>)
-            outs(%C : memref<f32>)
-  %res = load %C[] : memref<f32>
-
-  dealloc %bC : memref<?xi8>
-  dealloc %bB : memref<?xi8>
-  dealloc %bA : memref<?xi8>
-
-  return %res : f32
-}
-
-// Test for linalg.matmul.
-func @matmul() -> f32 {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %c6 = constant 6 : index
-  %c7 = constant 7 : index
-  %c2 = constant 2 : index
-  %c16 = constant 16 : index
-  %c4 = constant 4 : index
-  %c32 = constant 32 : index
-  %f1 = constant 1.00000e+00 : f32
-  %f2 = constant 2.00000e+00 : f32
-  %f10 = constant 10.00000e+00 : f32
-
-  %bA = call @alloc_filled_f32(%c32, %f2) : (index, f32) -> (memref<?xi8>)
-  %bB = call @alloc_filled_f32(%c32, %f1) : (index, f32) -> (memref<?xi8>)
-  %bC = call @alloc_filled_f32(%c4, %f10) : (index, f32) -> (memref<?xi8>)
-
-  %A = view %bA[%c0][%c2, %c16] : memref<?xi8> to memref<?x?xf32>
-  %B = view %bB[%c0][%c16, %c2] : memref<?xi8> to memref<?x?xf32>
-  %C = view %bC[%c0][%c2, %c2] : memref<?xi8> to memref<?x?xf32>
-
-  linalg.matmul ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
-               outs(%C : memref<?x?xf32>)
-  %res = load %C[%c0, %c1] : memref<?x?xf32>
-
-  dealloc %bC : memref<?xi8>
-  dealloc %bB : memref<?xi8>
-  dealloc %bA : memref<?xi8>
-
-  return %res : f32
-}
-
-// All tests return this value
-// CHECK: 4.2{{0+}}e+01

diff  --git a/mlir/test/mlir-cpu-runner/mlir_test_cblas.cpp b/mlir/test/mlir-cpu-runner/mlir_test_cblas.cpp
deleted file mode 100644
index 99a6a2658422..000000000000
--- a/mlir/test/mlir-cpu-runner/mlir_test_cblas.cpp
+++ /dev/null
@@ -1,46 +0,0 @@
-//===- mlir_test_cblas.cpp - Simple Blas subset implementation ------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Simple Blas subset implementation.
-//
-//===----------------------------------------------------------------------===//
-
-#include "include/mlir_test_cblas.h"
-#include <assert.h>
-
-extern "C" float mlir_test_cblas_sdot(const int N, const float *X,
-                                      const int incX, const float *Y,
-                                      const int incY) {
-  float res = 0.0f;
-  for (int i = 0; i < N; ++i)
-    res += X[i * incX] * Y[i * incY];
-  return res;
-}
-
-extern "C" void mlir_test_cblas_sgemm(
-    const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
-    const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
-    const float alpha, const float *A, const int lda, const float *B,
-    const int ldb, const float beta, float *C, const int ldc) {
-  assert(Order == CBLAS_ORDER::CblasRowMajor);
-  assert(TransA == CBLAS_TRANSPOSE::CblasNoTrans);
-  assert(TransB == CBLAS_TRANSPOSE::CblasNoTrans);
-  for (int m = 0; m < M; ++m) {
-    auto *pA = A + m * lda;
-    auto *pC = C + m * ldc;
-    for (int n = 0; n < N; ++n) {
-      float c = pC[n];
-      float res = 0.0f;
-      for (int k = 0; k < K; ++k) {
-        auto *pB = B + k * ldb;
-        res += pA[k] * pB[n];
-      }
-      pC[n] = alpha * c + beta * res;
-    }
-  }
-}

diff  --git a/mlir/test/mlir-cpu-runner/mlir_test_cblas_interface.cpp b/mlir/test/mlir-cpu-runner/mlir_test_cblas_interface.cpp
deleted file mode 100644
index eaeeaf6cd550..000000000000
--- a/mlir/test/mlir-cpu-runner/mlir_test_cblas_interface.cpp
+++ /dev/null
@@ -1,107 +0,0 @@
-//===- mlir_test_cblas_interface.cpp - Simple Blas subset interface -------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Simple Blas subset interface implementation.
-//
-//===----------------------------------------------------------------------===//
-
-#include "include/mlir_test_cblas_interface.h"
-#include "include/mlir_test_cblas.h"
-#include <assert.h>
-#include <iostream>
-
-extern "C" void
-_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f) {
-  X->data[X->offset] = f;
-}
-
-extern "C" void
-_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
-                                       float f) {
-  for (unsigned i = 0; i < X->sizes[0]; ++i)
-    *(X->data + X->offset + i * X->strides[0]) = f;
-}
-
-extern "C" void
-_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
-                                         float f) {
-  for (unsigned i = 0; i < X->sizes[0]; ++i)
-    for (unsigned j = 0; j < X->sizes[1]; ++j)
-      *(X->data + X->offset + i * X->strides[0] + j * X->strides[1]) = f;
-}
-
-extern "C" void
-_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
-                                         StridedMemRefType<float, 0> *O) {
-  O->data[O->offset] = I->data[I->offset];
-}
-
-extern "C" void
-_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
-                                             StridedMemRefType<float, 1> *O) {
-  if (I->sizes[0] != O->sizes[0]) {
-    std::cerr << "Incompatible strided memrefs\n";
-    printMemRefMetaData(std::cerr, *I);
-    printMemRefMetaData(std::cerr, *O);
-    return;
-  }
-  for (unsigned i = 0; i < I->sizes[0]; ++i)
-    O->data[O->offset + i * O->strides[0]] =
-        I->data[I->offset + i * I->strides[0]];
-}
-
-extern "C" void _mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
-    StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O) {
-  if (I->sizes[0] != O->sizes[0] || I->sizes[1] != O->sizes[1]) {
-    std::cerr << "Incompatible strided memrefs\n";
-    printMemRefMetaData(std::cerr, *I);
-    printMemRefMetaData(std::cerr, *O);
-    return;
-  }
-  auto so0 = O->strides[0], so1 = O->strides[1];
-  auto si0 = I->strides[0], si1 = I->strides[1];
-  for (unsigned i = 0; i < I->sizes[0]; ++i)
-    for (unsigned j = 0; j < I->sizes[1]; ++j)
-      O->data[O->offset + i * so0 + j * so1] =
-          I->data[I->offset + i * si0 + j * si1];
-}
-
-extern "C" void _mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
-    StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
-    StridedMemRefType<float, 0> *Z) {
-  if (X->strides[0] != 1 || Y->strides[0] != 1 || X->sizes[0] != Y->sizes[0]) {
-    std::cerr << "Incompatible strided memrefs\n";
-    printMemRefMetaData(std::cerr, *X);
-    printMemRefMetaData(std::cerr, *Y);
-    printMemRefMetaData(std::cerr, *Z);
-    return;
-  }
-  Z->data[Z->offset] +=
-      mlir_test_cblas_sdot(X->sizes[0], X->data + X->offset, X->strides[0],
-                           Y->data + Y->offset, Y->strides[0]);
-}
-
-extern "C" void _mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
-    StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
-    StridedMemRefType<float, 2> *C) {
-  if (A->strides[1] != B->strides[1] || A->strides[1] != C->strides[1] ||
-      A->strides[1] != 1 || A->sizes[0] < A->strides[1] ||
-      B->sizes[0] < B->strides[1] || C->sizes[0] < C->strides[1] ||
-      C->sizes[0] != A->sizes[0] || C->sizes[1] != B->sizes[1] ||
-      A->sizes[1] != B->sizes[0]) {
-    printMemRefMetaData(std::cerr, *A);
-    printMemRefMetaData(std::cerr, *B);
-    printMemRefMetaData(std::cerr, *C);
-    return;
-  }
-  mlir_test_cblas_sgemm(
-      CBLAS_ORDER::CblasRowMajor, CBLAS_TRANSPOSE::CblasNoTrans,
-      CBLAS_TRANSPOSE::CblasNoTrans, C->sizes[0], C->sizes[1], A->sizes[1],
-      1.0f, A->data + A->offset, A->strides[0], B->data + B->offset,
-      B->strides[0], 1.0f, C->data + C->offset, C->strides[0]);
-}


        


More information about the Mlir-commits mailing list