[Mlir-commits] [mlir] fc12602 - [mlir][sparse] fuse collapse_shape on sparse tensor with GenericOp.
Peiming Liu
llvmlistbot at llvm.org
Wed Mar 1 11:05:53 PST 2023
Author: Peiming Liu
Date: 2023-03-01T19:05:48Z
New Revision: fc126022e884da273d6796d97741fc15ce50fd54
URL: https://github.com/llvm/llvm-project/commit/fc126022e884da273d6796d97741fc15ce50fd54
DIFF: https://github.com/llvm/llvm-project/commit/fc126022e884da273d6796d97741fc15ce50fd54.diff
LOG: [mlir][sparse] fuse collapse_shape on sparse tensor with GenericOp.
Instead of always materializing a new sparse tensor after a reshape, this patch tries to fuse the reshape (currently only for COO tensors) with the GenericOp and co-iterates over the reshaped tensor without allocating a new sparse tensor.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D145016
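For illustration, this is the kind of kernel the fusion targets (excerpted from the new sparse_reshape_dot.mlir test added below); the collapse_shape on the COO operand is folded into the loops generated for the matmul instead of materializing an intermediate sparse tensor:

  %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]]
      : tensor<6x2x3xf32, #COO_3D> into tensor<6x6xf32, #COO_2D>
  %2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32>, tensor<6x6xf32, #COO_2D>)
                     outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32>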
Added:
mlir/test/Dialect/SparseTensor/sparse_reshape_dot.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index a52a0dadd42b9..8718697e0cc85 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -103,7 +103,7 @@ SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
/// Returns true iff the given type is a COO type where the last level
/// is unique.
-bool isUniqueCOOType(TensorType tp);
+bool isUniqueCOOType(Type tp);
/// Returns the starting level for a trailing COO region that spans
/// at least two levels. If no such COO region is found, then returns
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 740023a380af4..ec185c316a42a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -453,7 +453,7 @@ static bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl,
return !isUnique || enc.isUniqueLvl(lvlRank - 1);
}
-bool mlir::sparse_tensor::isUniqueCOOType(TensorType tp) {
+bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 72d341925cdad..bb2468580a57a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -132,31 +132,43 @@ LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
initialize(tensors, loopTag, hasOutput, isSparseOut, topSort);
}
-void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
- bool hasOutput, bool isSparseOut,
- ArrayRef<unsigned> topSort) {
+void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
+ bool isSparseOut, ArrayRef<unsigned> topSort) {
// First initializes fields.
this->loopTag = loopTag;
this->hasOutput = hasOutput;
this->isSparseOut = isSparseOut;
- this->tensors.assign(tensors.begin(), tensors.end());
+ this->tensors.assign(ts.begin(), ts.end());
this->isSparseSlices.assign(tensors.size(), false);
this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
this->pidxs.assign(tensors.size(), std::vector<Value>());
this->coord.assign(tensors.size(), std::vector<Value>());
this->highs.assign(tensors.size(), std::vector<Value>());
+ this->lvlSizes.assign(tensors.size(), std::vector<Value>());
this->ptrBuffer.assign(tensors.size(), std::vector<Value>());
this->idxBuffer.assign(tensors.size(), std::vector<Value>());
this->valBuffer.assign(tensors.size(), nullptr);
this->loopStack.reserve(topSort.size());
this->sparsiferLoopLvlMap.assign(topSort.size(), 0);
+ this->collapseReassoc.assign(tensors.size(), nullptr);
for (size_t tid = 0, e = tensors.size(); tid < e; tid++) {
auto t = tensors[tid];
    // A scalar or a 0-dimensional tensor.
if (isZeroRankedTensorOrScalar(t.getType()))
continue;
+
auto rtp = getRankedTensorType(t);
+ if (auto reshape = t.getDefiningOp<tensor::CollapseShapeOp>();
+ isUniqueCOOType(rtp) && reshape) {
+      // TODO: Support more kinds of sparse tensors.
+ // FIXME: We should instead lower reshape operations on sparse tensors to
+ // view change.
+ collapseReassoc[tid] = reshape.getReassociation();
+ rtp = reshape.getSrcType();
+      // Overwrite the tensor with the source tensor of the reshape operation.
+ tensors[tid] = t = reshape.getSrc();
+ }
auto rank = static_cast<size_t>(rtp.getRank());
auto enc = getSparseTensorEncoding(rtp);
// We always treat sparse output tensor as dense so that we always iterate
@@ -172,6 +184,7 @@ void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
pidxs[tid].assign(rank, Value());
coord[tid].assign(rank, Value());
highs[tid].assign(rank, Value());
+ lvlSizes[tid].assign(rank, Value());
ptrBuffer[tid].assign(rank, Value());
idxBuffer[tid].assign(rank, Value());
}
@@ -224,7 +237,8 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
// Find upper bound in current dimension.
// FIXME: `toOrigDim` is deprecated
const Dimension d = toOrigDim(enc, l);
- highs[t][l] = mlir::linalg::createOrFoldDimOp(builder, loc, tensor, d);
+ lvlSizes[t][l] = highs[t][l] =
+ mlir::linalg::createOrFoldDimOp(builder, loc, tensor, d);
}
// Perform the required bufferization. Dense inputs materialize
@@ -325,6 +339,8 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
}
auto enc = getSparseTensorEncoding(tensors[tid].getType());
+ auto reass = getCollapseReassociation(tid, dim);
+ dim = reass.front();
// TODO: support dynamic slices.
Value step = constantIndex(builder, loc, 1);
Value lo = isSparseInput ? pidxs[tid][dim] // current offset
@@ -334,6 +350,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
Operation *loop = nullptr;
Value iv;
if (isParallel) {
+ assert(collapseReassoc[tid] == nullptr);
scf::ParallelOp parOp =
builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
builder.setInsertionPointToStart(parOp.getBody());
@@ -365,10 +382,21 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
Value c;
if (isSparseInput) {
- pidxs[tid][dim] = iv;
- // Generating a load on the indices array yields the coordinate.
- Value ptr = idxBuffer[tid][dim];
- c = genIndexLoad(builder, loc, ptr, iv);
+ assert(reass.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
+ c = constantIndex(builder, loc, 0);
+ for (unsigned i = 0; i < reass.size(); i++) {
+ auto lvl = reass[i];
+ // For COO, the pidxs are always the same across consecutive levels.
+ pidxs[tid][lvl] = iv;
+ // Generating a load on the indices array yields the coordinate.
+ Value ptr = idxBuffer[tid][lvl];
+ Value off = genIndexLoad(builder, loc, ptr, iv);
+ c = builder.create<arith::AddIOp>(loc, c, off);
+ if (i != reass.size() - 1) {
+ c = builder.create<arith::MulIOp>(loc, c,
+ this->lvlSizes[tid][reass[i + 1]]);
+ }
+ }
} else {
    // Dense tensor, the coordinate is the induction variable.
c = iv;
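The AddIOp/MulIOp chain above linearizes the per-level COO coordinates of the collapsed levels into a single coordinate in row-major order. A minimal standalone sketch of the same arithmetic (the helper name is hypothetical, not part of the patch):

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Illustration only: mirrors the coordinate linearization emitted above.
  // coords[i] is the coordinate loaded from the i-th collapsed level and
  // lvlSizes[i] is that level's size; levels are combined row-major.
  int64_t linearizeCollapsedCoord(const std::vector<int64_t> &coords,
                                  const std::vector<int64_t> &lvlSizes) {
    assert(coords.size() == lvlSizes.size() && !coords.empty());
    int64_t c = 0;
    for (size_t i = 0; i < coords.size(); ++i) {
      c += coords[i];
      if (i + 1 != coords.size())
        c *= lvlSizes[i + 1]; // scale by the size of the next collapsed level
    }
    return c;
  }

For example, collapsing two levels of sizes {2, 3}: coordinate (1, 2) maps to 1 * 3 + 2 = 5, which matches the muli/addi sequence checked in the test below.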
@@ -643,27 +671,30 @@ void LoopEmitter::prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc,
if (isDenseDLT(dimType))
return;
- // Either the first dimension, or the previous dimension has been set.
- assert(dim == 0 || pidxs[tid][dim - 1]);
- Value c0 = constantIndex(builder, loc, 0);
- Value c1 = constantIndex(builder, loc, 1);
- if (isCompressedDLT(dimType)) {
- Value ptr = ptrBuffer[tid][dim];
-
- Value pLo = dim == 0 ? c0 : pidxs[tid][dim - 1];
- pidxs[tid][dim] = genIndexLoad(builder, loc, ptr, pLo);
-
- Value pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
- highs[tid][dim] = genIndexLoad(builder, loc, ptr, pHi);
- return;
- }
- if (isSingletonDLT(dimType)) {
- Value pLo = dim == 0 ? c0 : pidxs[tid][dim - 1];
- Value pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
+ auto reassoc = getCollapseReassociation(tid, dim);
+ for (auto lvl : reassoc) {
+ // Either the first dimension, or the previous dimension has been set.
+ assert(lvl == 0 || pidxs[tid][lvl - 1]);
+ Value c0 = constantIndex(builder, loc, 0);
+ Value c1 = constantIndex(builder, loc, 1);
+ if (isCompressedDLT(dimType)) {
+ Value ptr = ptrBuffer[tid][lvl];
+
+ Value pLo = lvl == 0 ? c0 : pidxs[tid][lvl - 1];
+ pidxs[tid][lvl] = genIndexLoad(builder, loc, ptr, pLo);
+
+ Value pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
+ highs[tid][lvl] = genIndexLoad(builder, loc, ptr, pHi);
+ return;
+ }
+ if (isSingletonDLT(dimType)) {
+ Value pLo = lvl == 0 ? c0 : pidxs[tid][lvl - 1];
+ Value pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
- pidxs[tid][dim] = pLo;
- highs[tid][dim] = pHi;
- return;
+ pidxs[tid][lvl] = pLo;
+ highs[tid][lvl] = pHi;
+ return;
+ }
}
  llvm_unreachable("Unrecognizable dimension type!");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 419d7039d1e6a..b8c287140b4c6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -256,6 +256,7 @@ class LoopEmitter {
std::vector<std::vector<Value>> pidxs;
std::vector<std::vector<Value>> coord;
std::vector<std::vector<Value>> highs;
+ std::vector<std::vector<Value>> lvlSizes;
std::vector<std::vector<Value>> ptrBuffer; // to_pointers
std::vector<std::vector<Value>> idxBuffer; // to_indices
std::vector<Value> valBuffer; // to_value
@@ -276,6 +277,28 @@ class LoopEmitter {
/// general.
std::vector<unsigned> sparsiferLoopLvlMap;
+ //
+  // View-based reshape related fields and methods
+ //
+
+  /// Collapse reassociations related to a specific tensor.
+ // TODO: support expand.
+ std::vector<ArrayAttr> collapseReassoc;
+
+ /// Get the collapse reassociation for tensors[tid] on l. For unreshaped
+ /// operands, the reassociation is simply an identity transformation.
+ SmallVector<int64_t, 2> getCollapseReassociation(unsigned tid, unsigned l) {
+    // Returns a SmallVector<int64_t, 2>, just like `ReassociationIndices`.
+ if (auto reass = collapseReassoc[tid]) {
+ auto attr = reass[l];
+ return llvm::to_vector<2>(
+ llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
+ return indexAttr.cast<IntegerAttr>().getInt();
+ }));
+ }
+ return {l};
+ }
+
/// TODO: not yet used, it should track the current level for each tensor
  /// to help eliminate `dim` parameters from the above APIs.
/// std::vector<size_t> curLv;
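For readers unfamiliar with reassociation maps, here is a minimal standalone C++ sketch of the lookup behaviour of getCollapseReassociation (the helper and types are illustrative stand-ins, not the patch's API):

  #include <cstdint>
  #include <vector>

  using ReassociationIndices = std::vector<int64_t>;

  // Illustrative stand-in: `collapse` is empty for an unreshaped operand
  // (identity mapping) and otherwise holds one group of source levels per
  // collapsed dimension, e.g. {{0}, {1, 2}}.
  static ReassociationIndices
  lookupReassociation(const std::vector<ReassociationIndices> &collapse,
                      int64_t l) {
    if (!collapse.empty())
      return collapse[static_cast<size_t>(l)]; // e.g. dim 1 of [[0], [1, 2]] -> {1, 2}
    return {l};                                // identity: dim l maps to level l only
  }

With the [[0], [1, 2]] reassociation used in the tests below, looking up dimension 1 yields levels {1, 2}, while an operand without a collapse_shape simply yields {1}.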
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape_dot.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape_dot.mlir
new file mode 100644
index 0000000000000..e52aca5533574
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape_dot.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops --sparsification --cse --canonicalize | FileCheck %s
+
+#COO_2D = #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ], pointerBitWidth = 32, indexBitWidth = 32 }>
+#COO_3D = #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton-nu", "singleton" ], pointerBitWidth = 32, indexBitWidth = 32 }>
+
+// CHECK-LABEL: func.func @sparse_reshape_fused(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x6xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<6x2x3xf32,
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = tensor.empty() : tensor<5x6xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index}
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index}
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index}
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 2 : index}
+// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]]
+// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_6]] : memref<5x6xf32>
+// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] {
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xi32>
+// CHECK: %[[VAL_15:.*]] = arith.extui %[[VAL_14]] : i32 to i64
+// CHECK: %[[VAL_16:.*]] = arith.index_cast %[[VAL_15]] : i64 to index
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xi32>
+// CHECK: %[[VAL_18:.*]] = arith.extui %[[VAL_17]] : i32 to i64
+// CHECK: %[[VAL_19:.*]] = arith.index_cast %[[VAL_18]] : i64 to index
+// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_16]] to %[[VAL_19]] step %[[VAL_5]] {
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xi32, strided<[?], offset: ?>>
+// CHECK: %[[VAL_22:.*]] = arith.extui %[[VAL_21]] : i32 to i64
+// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : i64 to index
+// CHECK: %[[VAL_24:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[VAL_13]], %[[VAL_23]]] : tensor<5x6xf32>
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<?xi32, strided<[?], offset: ?>>
+// CHECK: %[[VAL_26:.*]] = arith.extui %[[VAL_25]] : i32 to i64
+// CHECK: %[[VAL_27:.*]] = arith.index_cast %[[VAL_26]] : i64 to index
+// CHECK: %[[VAL_28:.*]] = arith.muli %[[VAL_27]], %[[VAL_3]] : index
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<?xi32, strided<[?], offset: ?>>
+// CHECK: %[[VAL_30:.*]] = arith.extui %[[VAL_29]] : i32 to i64
+// CHECK: %[[VAL_31:.*]] = arith.index_cast %[[VAL_30]] : i64 to index
+// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_28]], %[[VAL_31]] : index
+// CHECK: %[[VAL_33:.*]] = tensor.extract %[[VAL_6]]{{\[}}%[[VAL_13]], %[[VAL_32]]] : tensor<5x6xf32>
+// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<?xf32>
+// CHECK: %[[VAL_35:.*]] = arith.mulf %[[VAL_24]], %[[VAL_34]] : f32
+// CHECK: %[[VAL_36:.*]] = arith.addf %[[VAL_33]], %[[VAL_35]] : f32
+// CHECK: memref.store %[[VAL_36]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_32]]] : memref<5x6xf32>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_37:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<5x6xf32>
+// CHECK: %[[VAL_38:.*]] = tensor.expand_shape %[[VAL_37]] {{\[\[}}0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32>
+// CHECK: %[[VAL_39:.*]] = tensor.cast %[[VAL_38]] : tensor<5x2x3xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[VAL_39]] : tensor<?x?x?xf32>
+// CHECK: }
+func.func @sparse_reshape_fused(%arg0: tensor<5x6xf32>, %arg1: tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> {
+ %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<6x2x3xf32, #COO_3D> into tensor<6x6xf32, #COO_2D>
+ %0 = tensor.empty() : tensor<5x6xf32>
+ %2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32>, tensor<6x6xf32, #COO_2D>) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32>
+ %expanded = tensor.expand_shape %2 [[0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32>
+ %ret1 = tensor.cast %expanded : tensor<5x2x3xf32> to tensor<?x?x?xf32>
+ return %ret1 : tensor<?x?x?xf32>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
new file mode 100644
index 0000000000000..33606b01c89b7
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/reshape_dot.mlir
@@ -0,0 +1,108 @@
+// DEFINE: %{option} = enable-runtime-library=false
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
+// DEFINE: %{run} = TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \
+// DEFINE: mlir-cpu-runner \
+// DEFINE: -e entry -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = "enable-runtime-library=true"
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// RUN: %{compile} | %{run}
+
+#COO_2D = #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton" ], pointerBitWidth = 32, indexBitWidth = 32 }>
+#COO_3D = #sparse_tensor.encoding<{ dimLevelType = [ "compressed-nu", "singleton-nu", "singleton" ], pointerBitWidth = 32, indexBitWidth = 32 }>
+
+module {
+ func.func private @printMemref3dF32(%ptr : tensor<?x?x?xf32>) attributes { llvm.emit_c_interface }
+ func.func private @printMemref2dF32(%ptr : tensor<?x?xf32>) attributes { llvm.emit_c_interface }
+
+ func.func @test_sparse(%arg0: tensor<5x6xf32>, %arg1: tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32> {
+ %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<6x2x3xf32, #COO_3D> into tensor<6x6xf32, #COO_2D>
+ %0 = tensor.empty() : tensor<5x6xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32>
+ %2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32>, tensor<6x6xf32, #COO_2D>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32>
+ %expanded = tensor.expand_shape %2 [[0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32>
+ %ret1 = tensor.cast %expanded : tensor<5x2x3xf32> to tensor<?x?x?xf32>
+ return %ret1 : tensor<?x?x?xf32>
+ }
+
+ func.func @test_dense(%arg0: tensor<5x6xf32>, %arg1: tensor<6x2x3xf32>) -> tensor<?x?x?xf32> {
+ %collapsed = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<6x2x3xf32> into tensor<6x6xf32>
+ %0 = tensor.empty() : tensor<5x6xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<5x6xf32>) -> tensor<5x6xf32>
+ %2 = linalg.matmul ins(%arg0, %collapsed : tensor<5x6xf32>, tensor<6x6xf32>) outs(%1 : tensor<5x6xf32>) -> tensor<5x6xf32>
+ %expanded = tensor.expand_shape %2 [[0], [1, 2]] : tensor<5x6xf32> into tensor<5x2x3xf32>
+ %ret1 = tensor.cast %expanded : tensor<5x2x3xf32> to tensor<?x?x?xf32>
+ return %ret1 : tensor<?x?x?xf32>
+ }
+
+
+ func.func @entry() {
+    // Set up the two input tensors.
+ %d1 = arith.constant sparse<
+ [ [0, 0], [1, 1], [2, 2], [2, 3], [4, 5] ],
+ [1.0, 2.0, 3.0, 4.0, 5.0]
+ > : tensor<5x6xf32>
+
+ %d2 = arith.constant sparse<
+ [ [0, 0, 0], [1, 1, 1], [2, 1, 1] ],
+ [ 6.0, 7.0, 8.0]
+ > : tensor<6x2x3xf32>
+
+ // CHECK: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [5, 2, 3] strides = [6, 3, 1] data =
+ // CHECK-NEXT:[
+ // CHECK-SAME: [
+ // CHECK-SAME: [6, 0, 0],
+ // CHECK-NEXT: [0, 0, 0]],
+ // CHECK-NEXT: [
+ // CHECK-SAME: [0, 0, 0],
+ // CHECK-NEXT: [0, 14, 0]],
+ // CHECK-NEXT: [
+ // CHECK-SAME: [0, 0, 0],
+ // CHECK-NEXT: [0, 24, 0]],
+ // CHECK-NEXT: [
+ // CHECK-SAME: [0, 0, 0],
+ // CHECK-NEXT: [0, 0, 0]],
+ // CHECK-NEXT: [
+ // CHECK-SAME: [0, 0, 0],
+ // CHECK-NEXT: [0, 0, 0]]]
+ %do1 = call @test_dense(%d1, %d2) : (tensor<5x6xf32>, tensor<6x2x3xf32>) -> tensor<?x?x?xf32>
+ call @printMemref3dF32(%do1) : (tensor<?x?x?xf32>) -> ()
+
+ // Same results.
+ // CHECK-NEXT: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [5, 2, 3] strides = [6, 3, 1] data =
+ // CHECK-NEXT:[
+ // CHECK-SAME: [
+ // CHECK-SAME: [6, 0, 0],
+ // CHECK-NEXT: [0, 0, 0]],
+ // CHECK-NEXT: [
+ // CHECK-SAME: [0, 0, 0],
+ // CHECK-NEXT: [0, 14, 0]],
+ // CHECK-NEXT: [
+ // CHECK-SAME: [0, 0, 0],
+ // CHECK-NEXT: [0, 24, 0]],
+ // CHECK-NEXT: [
+ // CHECK-SAME: [0, 0, 0],
+ // CHECK-NEXT: [0, 0, 0]],
+ // CHECK-NEXT: [
+ // CHECK-SAME: [0, 0, 0],
+ // CHECK-NEXT: [0, 0, 0]]]
+ %s2 = sparse_tensor.convert %d2 : tensor<6x2x3xf32> to tensor<6x2x3xf32, #COO_3D>
+ %so1 = call @test_sparse(%d1, %s2): (tensor<5x6xf32>, tensor<6x2x3xf32, #COO_3D>) -> tensor<?x?x?xf32>
+ call @printMemref3dF32(%so1) : (tensor<?x?x?xf32>) -> ()
+
+ bufferization.dealloc_tensor %s2 : tensor<6x2x3xf32, #COO_3D>
+ bufferization.dealloc_tensor %do1 : tensor<?x?x?xf32>
+ bufferization.dealloc_tensor %so1 : tensor<?x?x?xf32>
+ return
+ }
+}