[Mlir-commits] [mlir] 81cb70e - [mlir][sparse] fix a bug in UnpackOp converter.
Peiming Liu
llvmlistbot at llvm.org
Tue Feb 14 18:36:06 PST 2023
Author: Peiming Liu
Date: 2023-02-15T02:36:00Z
New Revision: 81cb70e46ea1edb16fe97b5e44e3d710d64b2dbb
URL: https://github.com/llvm/llvm-project/commit/81cb70e46ea1edb16fe97b5e44e3d710d64b2dbb
DIFF: https://github.com/llvm/llvm-project/commit/81cb70e46ea1edb16fe97b5e44e3d710d64b2dbb.diff
LOG: [mlir][sparse] fix a bug in UnpackOp converter.
The UnpackOp converter used to create a ReallocOp unconditionally, but this can cause issues when the requested memory size is smaller than the actual storage size.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D144065
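To make the change concrete, here is a minimal sketch of the new lowering for the statically sized values buffer (the SSA names and the %buf operand are illustrative; the exact generated IR is what the updated test below checks). Before, the converter always emitted

  %0 = memref.realloc %buf : memref<?xf64> to memref<6xf64>

even when the underlying storage was already large enough. Now it guards the reallocation on the runtime buffer length:

  // Requested (static) length and actual buffer length.
  %c6 = arith.constant 6 : index
  %c0 = arith.constant 0 : index
  %n = memref.dim %buf, %c0 : memref<?xf64>
  // Reallocate only if more space is needed ...
  %p = arith.cmpi ugt, %c6, %n : index
  %0 = scf.if %p -> (memref<6xf64>) {
    %r = memref.realloc %buf : memref<?xf64> to memref<6xf64>
    scf.yield %r : memref<6xf64>
  } else {
    // ... otherwise return a subview of the existing storage.
    %s = memref.subview %buf[0] [6] [1] : memref<?xf64> to memref<6xf64>
    scf.yield %s : memref<6xf64>
  }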
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/sparse_pack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 17319b7ffa2a7..797a31892306f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -575,6 +575,34 @@ static void genEndInsert(OpBuilder &builder, Location loc,
   }
 }
 
+/// Returns a memref that fits the requested length (reallocates if requested
+/// length is larger, or creates a subview if it is smaller).
+static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
+                              Value buffer) {
+  MemRefType memTp = getMemRefType(buffer);
+  auto retTp = MemRefType::get(ArrayRef{len}, memTp.getElementType());
+
+  Value targetLen = constantIndex(builder, loc, len);
+  Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0);
+  Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
+                                                 targetLen, bufferLen);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, retTp, reallocP, true);
+  // If targetLen > bufferLen, reallocate to get enough space to return.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  Value reallocBuf = builder.create<memref::ReallocOp>(loc, retTp, buffer);
+  builder.create<scf::YieldOp>(loc, reallocBuf);
+  // Else, return a subview to fit the size.
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  Value subViewBuf = builder.create<memref::SubViewOp>(
+      loc, retTp, buffer, /*offset=*/ArrayRef<int64_t>{0},
+      /*size=*/ArrayRef<int64_t>{len},
+      /*stride=*/ArrayRef<int64_t>{1});
+  builder.create<scf::YieldOp>(loc, subViewBuf);
+  // Resets insertion point.
+  builder.setInsertionPointAfter(ifOp);
+  return ifOp.getResult(0);
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -1174,16 +1202,13 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
     // to ensure that we meet their need.
     TensorType dataTp = op.getData().getType();
     if (dataTp.hasStaticShape()) {
-      dataBuf = rewriter.create<memref::ReallocOp>(
-          loc, MemRefType::get(dataTp.getShape(), dataTp.getElementType()),
-          dataBuf);
+      dataBuf = reallocOrSubView(rewriter, loc, dataTp.getShape()[0], dataBuf);
     }
 
     TensorType indicesTp = op.getIndices().getType();
     if (indicesTp.hasStaticShape()) {
       auto len = indicesTp.getShape()[0] * indicesTp.getShape()[1];
-      flatBuf = rewriter.create<memref::ReallocOp>(
-          loc, MemRefType::get({len}, indicesTp.getElementType()), flatBuf);
+      flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
     }
 
     Value idxBuf = rewriter.create<memref::ExpandShapeOp>(
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index eeb41fe7d00a2..057153a20c955 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -43,14 +43,33 @@ func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x2xi32>)
 // CHECK-SAME:      %[[VAL_1:.*]]: memref<?xi32>,
 // CHECK-SAME:      %[[VAL_2:.*]]: memref<?xf64>,
 // CHECK-SAME:      %[[VAL_3:.*]]: !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_4:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
-// CHECK:           %[[VAL_5:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
-// CHECK:           %[[VAL_6:.*]] = memref.expand_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
-// CHECK:           %[[VAL_7:.*]] = bufferization.to_tensor %[[VAL_4]] : memref<6xf64>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<6x2xi32>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] val_mem_sz
-// CHECK:           %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : i32 to index
-// CHECK:           return %[[VAL_7]], %[[VAL_8]], %[[VAL_10]] : tensor<6xf64>, tensor<6x2xi32>, index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 6 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
+// CHECK:           %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
+// CHECK:           %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
+// CHECK:             %[[VAL_9:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
+// CHECK:             scf.yield %[[VAL_9]] : memref<6xf64>
+// CHECK:           } else {
+// CHECK:             %[[VAL_10:.*]] = memref.subview %[[VAL_2]][0] [6] [1] : memref<?xf64> to memref<6xf64>
+// CHECK:             scf.yield %[[VAL_10]] : memref<6xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_11:.*]] = arith.constant 12 : index
+// CHECK:           %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi32>
+// CHECK:           %[[VAL_13:.*]] = arith.cmpi ugt, %[[VAL_11]], %[[VAL_12]] : index
+// CHECK:           %[[VAL_14:.*]] = scf.if %[[VAL_13]] -> (memref<12xi32>) {
+// CHECK:             %[[VAL_15:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
+// CHECK:             scf.yield %[[VAL_15]] : memref<12xi32>
+// CHECK:           } else {
+// CHECK:             %[[VAL_16:.*]] = memref.subview %[[VAL_1]][0] [12] [1] : memref<?xi32> to memref<12xi32>
+// CHECK:             scf.yield %[[VAL_16]] : memref<12xi32>
+// CHECK:           }
+// CHECK:           %[[VAL_17:.*]] = memref.expand_shape %[[VAL_18:.*]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
+// CHECK:           %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64>
+// CHECK:           %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32>
+// CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier
+// CHECK:           %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : i32 to index
+// CHECK:           return %[[VAL_19]], %[[VAL_21]], %[[VAL_23]] : tensor<6xf64>, tensor<6x2xi32>, index
 // CHECK:         }
 func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {
   %d, %i, %nnz = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>