[Mlir-commits] [mlir] 19cde2d - [mlir][sparse] Improve concatenate operation conversion for the case with annotated all dense result.
llvmlistbot at llvm.org
Wed Dec 7 12:06:55 PST 2022
Author: bixia1
Date: 2022-12-07T12:06:50-08:00
New Revision: 19cde2df95f379d05fbb599f7d601003718dc91a
URL: https://github.com/llvm/llvm-project/commit/19cde2df95f379d05fbb599f7d601003718dc91a
DIFF: https://github.com/llvm/llvm-project/commit/19cde2df95f379d05fbb599f7d601003718dc91a.diff
LOG: [mlir][sparse] Improve concatenate operation conversion for the case with annotated all dense result.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D139345
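
For context, an "annotated all dense" result is a tensor whose type carries a sparse_tensor encoding in which every dimension level type is dense. Before this change such a result was still built through an intermediate COO tensor; with this change it is allocated directly and its values buffer is written in place. A minimal sketch of the targeted case, modeled on the test added below (the encoding and function names here are illustrative, not part of the patch):

#AllDense = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "dense" ],
  dimOrdering = affine_map<(i,j) -> (j,i)>
}>

#CSR = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ]
}>

// Concatenate a dense input and a sparse input into an all-dense annotated
// result. The conversion now creates the result with Action::kEmpty and
// stores elements straight into its values buffer instead of building a
// COO tensor first.
func.func @concat_example(%a: tensor<4x2xf64>,
                          %b: tensor<4x3xf64, #CSR>) -> tensor<4x5xf64, #AllDense> {
  %0 = sparse_tensor.concatenate %a, %b {dimension = 1 : index}
     : tensor<4x2xf64>, tensor<4x3xf64, #CSR> to tensor<4x5xf64, #AllDense>
  return %0 : tensor<4x5xf64, #AllDense>
}
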
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/test/Dialect/SparseTensor/sparse_concat.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index eb2b567a22219..ba3c94c1330f9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1318,7 +1318,7 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
matchAndRewrite(ConcatenateOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// The conversion works as follow:
- // (1). When output is sparse, and mix of inputs:
+ // (1). When output is sparse and not all dims are dense, and mix of inputs:
// a_sparse = concat (b_dense, c_sparse, ....)
// =>
// coo_for_a = newSparseCOO(shapeOf(a))
@@ -1331,10 +1331,10 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
// a = newSparseTensor(coo_for_a)
// return a
//
- // (2). When output is dense, and mix of inputs:
+ // (2). When output is dense or annotated all dense, and mix of inputs:
// a_dense = concat (b_dense, c_sparse, ....)
// =>
- // a = malloc(shapeOf(a))
+ // a = malloc(shapeOf(a)) or newSparseAllDense(shapeOf(a))
// for i, j, k // dense input
// a[ adjustForOffset(i,j,k) ] = b[i,j,k]
//
@@ -1362,18 +1362,50 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(),
concatDim);
+ bool allDense = false;
+ Value dstTensor;
if (encDst) {
- // Start a new COO for the destination tensor.
- dst =
- params.genBuffers(encDst, sizes, dstTp).genNewCall(Action::kEmptyCOO);
- dstPerm = params.getDim2LvlMap();
- elemPtr = genAllocaScalar(rewriter, loc, elemTp);
+ allDense = llvm::all_of(encDst.getDimLevelType(),
+ [](DimLevelType dlt) { return isDenseDLT(dlt); });
+ // Start a new COO or an initialized annotated all dense sparse tensor.
+ dst = params.genBuffers(encDst, sizes, dstTp)
+ .genNewCall(allDense ? Action::kEmpty : Action::kEmptyCOO);
dstIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
+ if (allDense) {
+ dstTensor = dst;
+ // Get the values buffer for the sparse tensor and reshape it to the
+ // corresponding dense tensor shape.
+ dst = genValuesCall(rewriter, loc,
+ MemRefType::get({ShapedType::kDynamic}, elemTp),
+ {dst});
+
+ // Use the dstIdx to store the level sizes.
+ SmallVector<Value> lvlSizes;
+ for (unsigned i = 0; i < sizes.size(); i++)
+ lvlSizes.push_back(sizes[toOrigDim(encDst, i)]);
+ storeIndices(rewriter, loc, rank, dstIdx, lvlSizes);
+ // The memref ReshapeOp requires the sizes buffer to have a static
+ // shape.
+ Value typedBuffer = rewriter.create<memref::CastOp>(
+ loc, MemRefType::get({rank}, rewriter.getIndexType()), dstIdx);
+ SmallVector<int64_t> shape(rank, ShapedType::kDynamic);
+ dst = rewriter.create<memref::ReshapeOp>(
+ loc, MemRefType::get(shape, elemTp), dst, typedBuffer);
+ } else {
+ dstPerm = params.getDim2LvlMap();
+ elemPtr = genAllocaScalar(rewriter, loc, elemTp);
+ }
} else {
// TODO: Dense buffers should be allocated/deallocated via the callback
// in BufferizationOptions.
dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
}
+ auto dimIdx2LvlIdx = [&](ValueRange dIdx) -> SmallVector<Value> {
+ SmallVector<Value> lIdx;
+ for (unsigned i = 0; i < dIdx.size(); i++)
+ lIdx.push_back(dIdx[toOrigDim(encDst, i)]);
+ return lIdx;
+ };
for (auto it : llvm::zip(op.getInputs(), adaptor.getInputs())) {
Value orignalOp = std::get<0>(it); // Input (with encoding) from Op
Value adaptedOp = std::get<1>(it); // Input (type converted) from adaptor
@@ -1384,24 +1416,29 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
rewriter, loc, adaptedOp, srcTp,
[&](OpBuilder &builder, Location loc, Value idx,
Value elemPtr) -> void {
- auto indVec =
+ SmallVector<Value> dimInd =
loadIndices(builder, loc, rank, idx, concatDim, offset);
- if (encDst) {
- // Case: sparse => sparse
- storeIndices(builder, loc, rank, dstIdx, indVec);
+ if (encDst && !allDense) {
+ // Case: sparse => sparse, except for annotated all dense.
+ storeIndices(builder, loc, rank, dstIdx, dimInd);
genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstIdx,
dstPerm);
} else {
- // Case: sparse => dense
- insertScalarIntoDenseTensor(builder, loc, elemPtr, dst, indVec);
+ // Case: sparse => dense, or annotated all dense.
+ SmallVector<Value> lvlInd;
+ if (allDense)
+ lvlInd = dimIdx2LvlIdx(dimInd);
+ else
+ lvlInd = dimInd;
+ insertScalarIntoDenseTensor(builder, loc, elemPtr, dst, lvlInd);
}
});
} else {
genDenseTensorIterationLoop(
rewriter, loc, adaptedOp, srcTp,
[&](OpBuilder &builder, Location loc, ValueRange idx) -> void {
- if (encDst) {
- // Case: dense => sparse
+ if (encDst && !allDense) {
+ // Case: dense => sparse, except for annotated all dense.
storeIndices(builder, loc, rank, dstIdx, idx, concatDim,
offset);
Value val = genValueForDense(builder, loc, adaptedOp, idx);
@@ -1409,13 +1446,15 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstIdx,
dstPerm);
} else {
- // Case: dense => dense
+ // Case: dense => dense, or annotated all dense.
Value val = genValueForDense(builder, loc, adaptedOp, idx);
- SmallVector<Value> indVec(idx);
+ SmallVector<Value> lvlInd(idx);
// Apply offset.
- indVec[concatDim] = builder.create<arith::AddIOp>(
- loc, indVec[concatDim], offset);
- builder.create<memref::StoreOp>(loc, val, dst, indVec);
+ lvlInd[concatDim] = builder.create<arith::AddIOp>(
+ loc, lvlInd[concatDim], offset);
+ if (allDense)
+ lvlInd = dimIdx2LvlIdx(lvlInd);
+ builder.create<memref::StoreOp>(loc, val, dst, lvlInd);
}
});
}
@@ -1427,11 +1466,15 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
}
if (encDst) {
- // In sparse output case, the destination holds the COO.
- Value coo = dst;
- dst = params.genNewCall(Action::kFromCOO, coo);
- // Release resources.
- genDelCOOCall(rewriter, loc, elemTp, coo);
+ if (!allDense) {
+ // In sparse output case, the destination holds the COO.
+ Value coo = dst;
+ dst = params.genNewCall(Action::kFromCOO, coo);
+ // Release resources.
+ genDelCOOCall(rewriter, loc, elemTp, coo);
+ } else {
+ dst = dstTensor;
+ }
rewriter.replaceOp(op, dst);
} else {
rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstTp, dst);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
index 021a76c388d5e..f997157ba8f6d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir
@@ -7,6 +7,11 @@
dimOrdering = affine_map<(i,j) -> (j,i)>
}>
+#SparseMatrix_D_P = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "dense" ],
+ dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
// CHECK-LABEL: func.func @concat_mix_dense(
// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64>,
// CHECK-SAME: %[[TMP_arg1:.*]]: !llvm.ptr<i8>)
@@ -102,9 +107,9 @@ func.func @concat_mix_dense(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spar
// CHECK-DAG: memref.store %[[TMP_c1]], %[[Iota_0]][%[[TMP_c1]]] : memref<2xindex>
// CHECK-DAG: %[[NullPtr:.*]] = llvm.mlir.null : !llvm.ptr<i8>
// CHECK: %[[TMP_7:.*]] = call @newSparseTensor(%[[DimSizesP_0]], %[[LvlSizesP_0]], %[[LvlTypesP_0]], %[[IotaP_0]], %[[IotaP_0]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c4_i32]], %[[NullPtr]])
-// CHECK: %[[TMP_8:.*]] = memref.alloca() : memref<f64>
// CHECK: %[[TMP_9:.*]] = memref.alloca() : memref<2xindex>
// CHECK: %[[TMP_10:.*]] = memref.cast %[[TMP_9]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[TMP_8:.*]] = memref.alloca() : memref<f64>
// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c2]] step %[[TMP_c1]] {
// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_c0]] to %[[TMP_c4]] step %[[TMP_c1]] {
// CHECK: memref.store %[[TMP_arg2]], %[[TMP_9]][%[[TMP_c0]]] : memref<2xindex>
@@ -192,9 +197,9 @@ func.func @concat_mix_sparse(%arg0: tensor<2x4xf64>, %arg1: tensor<3x4xf64, #Spa
// CHECK-DAG: memref.store %[[TMP_c0]], %[[Dim2Lvl_0]][%[[TMP_c1]]] : memref<2xindex>
// CHECK-DAG: %[[NullPtr:.*]] = llvm.mlir.null : !llvm.ptr<i8>
// CHECK: %[[TMP_7:.*]] = call @newSparseTensor(%[[DimSizesP_0]], %[[LvlSizesP_0]], %[[LvlTypesP_0]], %[[Lvl2DimP_0]], %[[Dim2LvlP_0]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c4_i32]], %[[NullPtr]])
-// CHECK: %[[TMP_8:.*]] = memref.alloca() : memref<f64>
// CHECK: %[[TMP_9:.*]] = memref.alloca() : memref<2xindex>
// CHECK: %[[TMP_10:.*]] = memref.cast %[[TMP_9]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[TMP_8:.*]] = memref.alloca() : memref<f64>
// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c4]] step %[[TMP_c1]] {
// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_c0]] to %[[TMP_c2]] step %[[TMP_c1]] {
// CHECK: memref.store %[[TMP_arg2]], %[[TMP_9]][%[[TMP_c0]]] : memref<2xindex>
@@ -367,10 +372,91 @@ func.func @concat_mix_dense_perm_dim1(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3x
// CHECK: call @delSparseTensorIteratorF64(%[[TMP_8]]) : (!llvm.ptr<i8>) -> ()
// CHECK: %[[TMP_12:.*]] = bufferization.to_tensor %[[TMP_1]] : memref<?x?xf64>
// CHECK: return %[[TMP_12]] : tensor<?x?xf64>
-// CHECK: }
// CHECK: }
func.func @concat_mix_dense_perm_dim1_dyn(%arg0: tensor<3x2xf64>, %arg1: tensor<3x3xf64, #SparseMatrix>) -> tensor<?x?xf64> {
%0 = sparse_tensor.concatenate %arg0, %arg1 {dimension = 1 : index}
: tensor<3x2xf64>, tensor<3x3xf64, #SparseMatrix> to tensor<?x?xf64>
return %0 : tensor<?x?xf64>
}
+
+// CHECK-LABEL: func.func @concat_annotated_dense(
+// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<4x2xf64>,
+// CHECK-SAME: %[[TMP_arg1:.*]]: !llvm.ptr<i8>)
+// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32
+// CHECK-DAG: %[[TMP_c4_i8:.*]] = arith.constant 4 : i8
+// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8
+// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[LvlTypes_0:.*]] = memref.alloca() : memref<2xi8>
+// CHECK-DAG: %[[LvlTypesP_0:.*]] = memref.cast %[[LvlTypes_0]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG: memref.store %[[TMP_c4_i8]], %[[LvlTypes_0]][%[[TMP_c0]]] : memref<2xi8>
+// CHECK-DAG: memref.store %[[TMP_c4_i8]], %[[LvlTypes_0]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG: %[[DimSizes_0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[DimSizesP_0:.*]] = memref.cast %[[DimSizes_0]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes_0]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK-DAG: memref.store %[[TMP_c5]], %[[DimSizes_0]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK-DAG: %[[LvlSizes_0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[LvlSizesP_0:.*]] = memref.cast %[[LvlSizes_0]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: %[[Lvl2Dim_0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[Lvl2DimP_0:.*]] = memref.cast %[[Lvl2Dim_0]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: %[[Dim2Lvl_0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[Dim2LvlP_0:.*]] = memref.cast %[[Dim2Lvl_0]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: memref.store %[[TMP_c1]], %[[Dim2Lvl_0]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK-DAG: memref.store %[[TMP_c0]], %[[Dim2Lvl_0]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK-DAG: %[[NullPtr:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+// CHECK: %[[TMP_7:.*]] = call @newSparseTensor(%[[DimSizesP_0]], %[[LvlSizesP_0]], %[[LvlTypesP_0]], %[[Lvl2DimP_0]], %[[Dim2LvlP_0]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c0_i32]], %[[NullPtr]])
+// CHECK: %[[Values_r:.*]] = call @sparseValuesF64(%[[TMP_7]]) : (!llvm.ptr<i8>) -> memref<?xf64>
+// CHECK: %[[Values:.*]] = memref.reshape %[[Values_r]]
+// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c4]] step %[[TMP_c1]] {
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_c0]] to %[[TMP_c2]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_22:.*]] = tensor.extract %[[TMP_arg0]][%[[TMP_arg2]], %[[TMP_arg3]]] : tensor<4x2xf64>
+// CHECK: %[[TMP_23:.*]] = arith.cmpf une, %[[TMP_22]], %[[TMP_cst]] : f64
+// CHECK: scf.if %[[TMP_23]] {
+// CHECK: memref.store %[[TMP_22]], %[[Values]][%[[TMP_arg3]], %[[TMP_arg2]]] : memref<?x?xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK-DAG: %[[LvlTypes_1:.*]] = memref.alloca() : memref<2xi8>
+// CHECK-DAG: %[[LvlTypesP_1:.*]] = memref.cast %[[LvlTypes_1]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c0]]] : memref<2xi8>
+// CHECK-DAG: memref.store %[[TMP_c8_i8]], %[[LvlTypes_1]][%[[TMP_c1]]] : memref<2xi8>
+// CHECK-DAG: %[[DimSizes_1:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[DimSizesP_1:.*]] = memref.cast %[[DimSizes_1]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: memref.store %[[TMP_c4]], %[[DimSizes_1]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK-DAG: memref.store %[[TMP_c3]], %[[DimSizes_1]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK-DAG: %[[LvlSizes_1:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[LvlSizesP_1:.*]] = memref.cast %[[LvlSizes_1]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: %[[Iota_1:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[IotaP_1:.*]] = memref.cast %[[Iota_1]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: memref.store %[[TMP_c0]], %[[Iota_1]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK-DAG: memref.store %[[TMP_c1]], %[[Iota_1]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK: %[[TMP_17:.*]] = call @newSparseTensor(%[[DimSizesP_1]], %[[LvlSizesP_1]], %[[LvlTypesP_1]], %[[IotaP_1]], %[[IotaP_1]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c6_i32]], %[[TMP_arg1]])
+// CHECK: %[[TMP_18:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[TMP_19:.*]] = memref.cast %[[TMP_18]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[TMP_20:.*]] = memref.alloca() : memref<f64>
+// CHECK: scf.while : () -> () {
+// CHECK: %[[TMP_22:.*]] = func.call @getNextF64(%[[TMP_17]], %[[TMP_19]], %[[TMP_20]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f64>) -> i1
+// CHECK: scf.condition(%[[TMP_22]])
+// CHECK: } do {
+// CHECK: %[[TMP_22:.*]] = memref.load %[[TMP_18]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_18]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index
+// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_20]][] : memref<f64>
+// CHECK: memref.store %[[TMP_25]], %[[Values]][%[[TMP_24]], %[[TMP_22]]] : memref<?x?xf64>
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: call @delSparseTensorIteratorF64(%[[TMP_17]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: return %[[TMP_7]] : !llvm.ptr<i8>
+// CHECK: }
+func.func @concat_annotated_dense(%arg0: tensor<4x2xf64>, %arg1: tensor<4x3xf64, #SparseMatrix_P>) -> tensor<4x5xf64, #SparseMatrix_D_P> {
+ %0 = sparse_tensor.concatenate %arg0, %arg1 {dimension = 1 : index}
+ : tensor<4x2xf64>, tensor<4x3xf64, #SparseMatrix_P> to tensor<4x5xf64, #SparseMatrix_D_P>
+ return %0 : tensor<4x5xf64, #SparseMatrix_D_P>
+}
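
The new @concat_annotated_dense test above exercises that path: the result is created by newSparseTensor with an all-dense encoding, its values buffer is fetched via sparseValuesF64 and reshaped to the level sizes, and both inputs store directly into it at level coordinates (the indices are swapped because of the permuted dimOrdering). A condensed, self-contained sketch of that store pattern (the helper @store_into_all_dense is hypothetical; sparseValuesF64 is the runtime entry point seen in the CHECK lines):

// Runtime entry point used by the conversion to obtain the values buffer.
func.func private @sparseValuesF64(!llvm.ptr<i8>) -> memref<?xf64>

// Hypothetical helper: store one value into an annotated all-dense 2-D
// result whose dimOrdering is (i,j) -> (j,i), i.e. level order is (j,i).
func.func @store_into_all_dense(%dst: !llvm.ptr<i8>, %lvlSizes: memref<2xindex>,
                                %i: index, %j: index, %v: f64) {
  // View the flat values storage with the (dynamic) level sizes.
  %vals1d = call @sparseValuesF64(%dst) : (!llvm.ptr<i8>) -> memref<?xf64>
  %vals = memref.reshape %vals1d(%lvlSizes)
        : (memref<?xf64>, memref<2xindex>) -> memref<?x?xf64>
  // Store at level coordinates: dimension indices (i, j) permuted to (j, i).
  memref.store %v, %vals[%j, %i] : memref<?x?xf64>
  return
}
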