[Mlir-commits] [mlir] fde04ae - [mlir][sparse] refine bufferization allocation lowering
Aart Bik
llvmlistbot at llvm.org
Tue Jun 21 15:17:34 PDT 2022
Author: Aart Bik
Date: 2022-06-21T15:17:25-07:00
New Revision: fde04aee33f4f530c8b3942210fe3daa69e915a7
URL: https://github.com/llvm/llvm-project/commit/fde04aee33f4f530c8b3942210fe3daa69e915a7
DIFF: https://github.com/llvm/llvm-project/commit/fde04aee33f4f530c8b3942210fe3daa69e915a7.diff
LOG: [mlir][sparse] refine bufferization allocation lowering
Marking the bufferization allocation operation as illegal
during sparse lowering is too strict, since dense and
sparse allocations can co-exist. This revision refines
the lowering with a dynamic legality check on the result
type of the allocation.
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D128305
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/conversion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 5bbb4a5e1ddac..66b92ec3b3601 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -92,9 +92,9 @@ struct SparseTensorConversionPass
ConversionTarget target(*ctx);
// Everything in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
- // All dynamic rules below accept new function, call, return, and tensor
- // dim and cast operations as legal output of the rewriting provided that
- // all sparse tensor types have been fully rewritten.
+ // All dynamic rules below accept new function, call, return, and various
+ // tensor and bufferization operations as legal output of the rewriting
+ // provided that all sparse tensor types have been fully rewritten.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
});
@@ -110,6 +110,10 @@ struct SparseTensorConversionPass
target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
return converter.isLegal(op.getOperand().getType());
});
+ target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
+ [&](bufferization::AllocTensorOp op) {
+ return converter.isLegal(op.getType());
+ });
// The following operations and dialects may be introduced by the
// rewriting rules, and are therefore marked as legal.
target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
@@ -119,7 +123,6 @@ struct SparseTensorConversionPass
target
.addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
memref::MemRefDialect, scf::SCFDialect>();
- target.addIllegalOp<bufferization::AllocTensorOp>();
// Translate strategy flags to strategy options.
SparseTensorConversionOptions options(
sparseToSparseConversionStrategy(sparseToSparse));
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 2a85d012b98e3..d9b3ed1c2a3bc 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -572,3 +572,14 @@ func.func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr
sparse_tensor.out %arg0, %arg1 : tensor<?x?x?xf32, #SparseTensor>, !llvm.ptr<i8>
return
}
+
+// CHECK-LABEL: func @sparse_and_dense_init(
+// CHECK: %[[S:.*]] = call @newSparseTensor
+// CHECK: %[[D:.*]] = bufferization.alloc_tensor
+// CHECK: return %[[S]], %[[D]] : !llvm.ptr<i8>, tensor<?x?xf64>
+func.func @sparse_and_dense_init(%arg0: index, %arg1: index)
+ -> (tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>) {
+ %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix>
+ %1 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64>
+ return %0, %1 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>
+}
More information about the Mlir-commits
mailing list