[Mlir-commits] [mlir] e383eaa - [mlir][sparse] parameterize MTTKRP kernel
Aart Bik
llvmlistbot at llvm.org
Fri Jan 14 16:20:39 PST 2022
Author: Aart Bik
Date: 2022-01-14T16:20:31-08:00
New Revision: e383eaa647da3a3f45313d22d83b6d6b5cdc9fc1
URL: https://github.com/llvm/llvm-project/commit/e383eaa647da3a3f45313d22d83b6d6b5cdc9fc1
DIFF: https://github.com/llvm/llvm-project/commit/e383eaa647da3a3f45313d22d83b6d6b5cdc9fc1.diff
LOG: [mlir][sparse] parameterize MTTKRP kernel
Rather than hardcoding all constants, we now use the input tensor to drive the
code setup. Of course, we still need to hardcode dim-2 of A, and the final
verification in CHECK is input-dependent, but overall this sets a slightly
better example of tensor setup in general.
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D117349
Added:
Modified:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir
Removed:
################################################################################
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir
index ca1287387d72e..34934093f5b67 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir
@@ -75,49 +75,53 @@ module {
// Main driver that reads matrix from file and calls the sparse kernel.
//
func @entry() {
- %i0 = arith.constant 0. : f64
+ %f0 = arith.constant 0.0 : f64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
- %c3 = arith.constant 3 : index
- %c4 = arith.constant 4 : index
- %c5 = arith.constant 5 : index
- %c256 = arith.constant 256 : index
- // Read the sparse B input from a file.
+ // Read the sparse input tensor B from a file.
%fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
%b = sparse_tensor.new %fileName
: !Filename to tensor<?x?x?xf64, #SparseTensor>
- // Initialize dense C and D inputs and dense output A.
- %cdata = memref.alloc(%c3, %c5) : memref<?x?xf64>
- scf.for %i = %c0 to %c3 step %c1 {
- scf.for %j = %c0 to %c5 step %c1 {
- %k0 = arith.muli %i, %c5 : index
+ // Get sizes from B, pick a fixed size for dim-2 of A.
+ %isz = tensor.dim %b, %c0 : tensor<?x?x?xf64, #SparseTensor>
+ %jsz = arith.constant 5 : index
+ %ksz = tensor.dim %b, %c1 : tensor<?x?x?xf64, #SparseTensor>
+ %lsz = tensor.dim %b, %c2 : tensor<?x?x?xf64, #SparseTensor>
+
+ // Initialize dense input matrix C.
+ %cdata = memref.alloc(%ksz, %jsz) : memref<?x?xf64>
+ scf.for %k = %c0 to %ksz step %c1 {
+ scf.for %j = %c0 to %jsz step %c1 {
+ %k0 = arith.muli %k, %jsz : index
%k1 = arith.addi %k0, %j : index
%k2 = arith.index_cast %k1 : index to i32
- %k = arith.sitofp %k2 : i32 to f64
- memref.store %k, %cdata[%i, %j] : memref<?x?xf64>
+ %kf = arith.sitofp %k2 : i32 to f64
+ memref.store %kf, %cdata[%k, %j] : memref<?x?xf64>
}
}
%c = bufferization.to_tensor %cdata : memref<?x?xf64>
- %ddata = memref.alloc(%c4, %c5) : memref<?x?xf64>
- scf.for %i = %c0 to %c4 step %c1 {
- scf.for %j = %c0 to %c5 step %c1 {
- %k0 = arith.muli %i, %c5 : index
+ // Initialize dense input matrix D.
+ %ddata = memref.alloc(%lsz, %jsz) : memref<?x?xf64>
+ scf.for %l = %c0 to %lsz step %c1 {
+ scf.for %j = %c0 to %jsz step %c1 {
+ %k0 = arith.muli %l, %jsz : index
%k1 = arith.addi %k0, %j : index
%k2 = arith.index_cast %k1 : index to i32
- %k = arith.sitofp %k2 : i32 to f64
- memref.store %k, %ddata[%i, %j] : memref<?x?xf64>
+ %kf = arith.sitofp %k2 : i32 to f64
+ memref.store %kf, %ddata[%l, %j] : memref<?x?xf64>
}
}
%d = bufferization.to_tensor %ddata : memref<?x?xf64>
- %adata = memref.alloc(%c2, %c5) : memref<?x?xf64>
- scf.for %i = %c0 to %c2 step %c1 {
- scf.for %j = %c0 to %c5 step %c1 {
- memref.store %i0, %adata[%i, %j] : memref<?x?xf64>
+ // Initialize dense output matrix A.
+ %adata = memref.alloc(%isz, %jsz) : memref<?x?xf64>
+ scf.for %i = %c0 to %isz step %c1 {
+ scf.for %j = %c0 to %jsz step %c1 {
+ memref.store %f0, %adata[%i, %j] : memref<?x?xf64>
}
}
%a = bufferization.to_tensor %adata : memref<?x?xf64>
@@ -133,7 +137,7 @@ module {
// CHECK: ( 10000, 14225, 19180, 24865, 31280 ) )
//
%m = bufferization.to_memref %0 : memref<?x?xf64>
- %v = vector.transfer_read %m[%c0, %c0], %i0
+ %v = vector.transfer_read %m[%c0, %c0], %f0
: memref<?x?xf64>, vector<2x5xf64>
vector.print %v : vector<2x5xf64>
More information about the Mlir-commits
mailing list