[Mlir-commits] [mlir] [mlir][sparse] support 'batch' dimensions in sparse_tensor.print (PR #91411)
Aart Bik
llvmlistbot at llvm.org
Tue May 7 16:42:40 PDT 2024
https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/91411
None
>From b1a1d322e8bd31653a879060145aef42325a1fda Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 7 May 2024 16:32:04 -0700
Subject: [PATCH] [mlir][sparse] support 'batch' dimensions in
sparse_tensor.print
---
.../Transforms/SparseTensorCodegen.cpp | 12 ++-
.../Transforms/SparseTensorRewriting.cpp | 66 ++++++++++-------
.../SparseTensor/CPU/sparse_pack_d.mlir | 12 +--
.../SparseTensor/CPU/sparse_print_3d.mlir | 74 +++++++++++++++++++
4 files changed, 130 insertions(+), 34 deletions(-)
create mode 100755 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d9b203a886488..164e722c45dba 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -417,11 +417,17 @@ static void genEndInsert(OpBuilder &builder, Location loc,
/// Generates a subview into the sizes.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
Value sz) {
- auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
+ auto memTp = llvm::cast<MemRefType>(mem.getType());
+ // For higher-dimensional memrefs, we assume that the innermost
+ // dimension is always of the right size.
+ // TODO: generate a (more involved) truncating view for this case too?
+ if (memTp.getRank() > 1)
+ return mem;
+ // Truncate linear memrefs to given size.
return builder
.create<memref::SubViewOp>(
- loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
- ValueRange{}, ValueRange{sz}, ValueRange{},
+ loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
+ mem, ValueRange{}, ValueRange{sz}, ValueRange{},
ArrayRef<int64_t>{0}, // static offset
ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
ArrayRef<int64_t>{1}) // static stride
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 7d469198a653c..025fd3331ba89 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -785,45 +785,61 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
}
private:
- // Helper to print contents of a single memref. Note that for the "push_back"
- // vectors, this prints the full capacity, not just the size. This is done
- // on purpose, so that clients see how much storage has been allocated in
- // total. Contents of the extra capacity in the buffer may be uninitialized
- // (unless the flag enable-buffer-initialization is set to true).
+ // Helper to print contents of a single memref. For "push_back" vectors,
+ // we assume that the previous getters for pos/crd/val have added a
+ // slice-to-size view to make sure we just print the size and not the
+ // full capacity.
//
- // Generates code to print:
+ // Generates code to print (1-dim or higher):
// ( a0, a1, ... )
static void printContents(PatternRewriter &rewriter, Location loc,
Value vec) {
+ auto shape = cast<ShapedType>(vec.getType()).getShape();
+ SmallVector<Value> idxs;
+ printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
+ rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
+ }
+
+ // Recursive helper: prints dimension i of vec, recursing into inner ones.
+ static void printContentsLevel(PatternRewriter &rewriter, Location loc,
+ Value vec, unsigned i, ArrayRef<int64_t> shape,
+ SmallVectorImpl<Value> &idxs) {
// Open bracket.
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
- // For loop over elements.
+ // Generate for loop.
auto zero = constantIndex(rewriter, loc, 0);
- auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+ auto index = constantIndex(rewriter, loc, i);
+ auto size = rewriter.create<memref::DimOp>(loc, vec, index);
auto step = constantIndex(rewriter, loc, 1);
auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
+ idxs.push_back(forOp.getInductionVar());
rewriter.setInsertionPointToStart(forOp.getBody());
- auto idx = forOp.getInductionVar();
- auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
- if (llvm::isa<ComplexType>(val.getType())) {
- // Since the vector dialect does not support complex types in any op,
- // we split those into (real, imag) pairs here.
- Value real = rewriter.create<complex::ReOp>(loc, val);
- Value imag = rewriter.create<complex::ImOp>(loc, val);
- rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
- rewriter.create<vector::PrintOp>(loc, real,
- vector::PrintPunctuation::Comma);
- rewriter.create<vector::PrintOp>(loc, imag,
- vector::PrintPunctuation::Close);
- rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+ if (i < shape.size() - 1) {
+ // Enter deeper loop nest.
+ printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
} else {
- rewriter.create<vector::PrintOp>(loc, val,
- vector::PrintPunctuation::Comma);
+ // Actual contents printing.
+ auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
+ if (llvm::isa<ComplexType>(val.getType())) {
+ // Since the vector dialect does not support complex types in any op,
+ // we split those into (real, imag) pairs here.
+ Value real = rewriter.create<complex::ReOp>(loc, val);
+ Value imag = rewriter.create<complex::ImOp>(loc, val);
+ rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+ rewriter.create<vector::PrintOp>(loc, real,
+ vector::PrintPunctuation::Comma);
+ rewriter.create<vector::PrintOp>(loc, imag,
+ vector::PrintPunctuation::Close);
+ rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
+ } else {
+ rewriter.create<vector::PrintOp>(loc, val,
+ vector::PrintPunctuation::Comma);
+ }
}
+ idxs.pop_back();
rewriter.setInsertionPointAfter(forOp);
- // Close bracket and end of line.
+ // Close bracket.
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
- rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
}
// Helper method to print run-time lvl/dim sizes.
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
index 20ae7e86285cc..467a77f30777a 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir
@@ -29,7 +29,7 @@
crdWidth = 32
}>
-#BatchedCSR = #sparse_tensor.encoding<{
+#DenseCSR = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed),
posWidth = 64,
crdWidth = 32
@@ -42,7 +42,7 @@
}>
//
-// Test assembly operation with CCC, batched-CSR and CSR-dense.
+// Test assembly operation with CCC, dense-CSR and CSR-dense.
//
module {
//
@@ -77,7 +77,7 @@ module {
tensor<6xi64>, tensor<8xi32>), tensor<8xf32> to tensor<4x3x2xf32, #CCC>
//
- // Setup BatchedCSR.
+ // Setup DenseCSR.
//
%data1 = arith.constant dense<
@@ -88,7 +88,7 @@ module {
%crd1 = arith.constant dense<
[ 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]> : tensor<16xi32>
- %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #BatchedCSR>
+ %s1 = sparse_tensor.assemble (%pos1, %crd1), %data1 : (tensor<13xi64>, tensor<16xi32>), tensor<16xf32> to tensor<4x3x2xf32, #DenseCSR>
//
// Setup CSRDense.
@@ -137,7 +137,7 @@ module {
// CHECK-NEXT: ----
//
sparse_tensor.print %s0 : tensor<4x3x2xf32, #CCC>
- sparse_tensor.print %s1 : tensor<4x3x2xf32, #BatchedCSR>
+ sparse_tensor.print %s1 : tensor<4x3x2xf32, #DenseCSR>
sparse_tensor.print %s2 : tensor<4x3x2xf32, #CSRDense>
// TODO: This check is no longer needed once the codegen path uses the
@@ -148,7 +148,7 @@ module {
// sparse_tensor.assemble copies buffers when running with the runtime
// library. Deallocations are not needed when running in codegen mode.
bufferization.dealloc_tensor %s0 : tensor<4x3x2xf32, #CCC>
- bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #BatchedCSR>
+ bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #DenseCSR>
bufferization.dealloc_tensor %s2 : tensor<4x3x2xf32, #CSRDense>
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir
new file mode 100755
index 0000000000000..98dee304fa511
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print_3d.mlir
@@ -0,0 +1,74 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+// do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
+// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e main -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// TODO: make this work with libgen
+
+// Run with direct IR generation only (the libgen path is not supported yet).
+// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+
+#BatchedCSR = #sparse_tensor.encoding<{
+ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed)
+}>
+
+module {
+
+ //
+ // Main driver that tests 3-D sparse tensor printing.
+ //
+ func.func @main() {
+
+ %pos = arith.constant dense<
+ [[ 0, 8, 16, 24, 32],
+ [ 0, 8, 16, 24, 32]]
+ > : tensor<2x5xindex>
+
+ %crd = arith.constant dense<
+ [[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7],
+ [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]]
+ > : tensor<2x32xindex>
+
+ %val = arith.constant dense<
+ [[ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.,
+ 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22.,
+ 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.],
+ [33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43.,
+ 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54.,
+ 55., 56., 57., 58., 59., 60., 61., 62., 63., 64.]]
+ > : tensor<2x32xf64>
+
+ %X = sparse_tensor.assemble (%pos, %crd), %val
+ : (tensor<2x5xindex>, tensor<2x32xindex>), tensor<2x32xf64> to tensor<2x4x8xf64, #BatchedCSR>
+
+ // CHECK: ---- Sparse Tensor ----
+ // CHECK-NEXT: nse = 32
+ // CHECK-NEXT: dim = ( 2, 4, 8 )
+ // CHECK-NEXT: lvl = ( 2, 4, 8 )
+ // CHECK-NEXT: pos[2] : ( ( 0, 8, 16, 24, 32, )( 0, 8, 16, 24, 32, ) )
+ // CHECK-NEXT: crd[2] : ( ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, )
+ // CHECK-SAME: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, ) )
+ // CHECK-NEXT: values : ( ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, )
+ // CHECK-SAME: ( 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, ) )
+ // CHECK-NEXT: ----
+ sparse_tensor.print %X : tensor<2x4x8xf64, #BatchedCSR>
+
+ return
+ }
+}
More information about the Mlir-commits
mailing list