[Mlir-commits] [mlir] fde04ae - [mlir][sparse] refine bufferization allocation lowering

Aart Bik llvmlistbot at llvm.org
Tue Jun 21 15:17:34 PDT 2022


Author: Aart Bik
Date: 2022-06-21T15:17:25-07:00
New Revision: fde04aee33f4f530c8b3942210fe3daa69e915a7

URL: https://github.com/llvm/llvm-project/commit/fde04aee33f4f530c8b3942210fe3daa69e915a7
DIFF: https://github.com/llvm/llvm-project/commit/fde04aee33f4f530c8b3942210fe3daa69e915a7.diff

LOG: [mlir][sparse] refine bufferization allocation lowering

Marking the bufferization allocation operation as illegal
during sparse lowering is too strict, since dense and
sparse allocations can co-exist. This revision refines
the lowering with a dynamic legality check on the result type.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D128305

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 5bbb4a5e1ddac..66b92ec3b3601 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -92,9 +92,9 @@ struct SparseTensorConversionPass
     ConversionTarget target(*ctx);
     // Everything in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
-    // All dynamic rules below accept new function, call, return, and tensor
-    // dim and cast operations as legal output of the rewriting provided that
-    // all sparse tensor types have been fully rewritten.
+    // All dynamic rules below accept new function, call, return, and various
+    // tensor and bufferization operations as legal output of the rewriting
+    // provided that all sparse tensor types have been fully rewritten.
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return converter.isSignatureLegal(op.getFunctionType());
     });
@@ -110,6 +110,10 @@ struct SparseTensorConversionPass
     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
       return converter.isLegal(op.getOperand().getType());
     });
+    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
+        [&](bufferization::AllocTensorOp op) {
+          return converter.isLegal(op.getType());
+        });
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
     target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
@@ -119,7 +123,6 @@ struct SparseTensorConversionPass
     target
         .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
                          memref::MemRefDialect, scf::SCFDialect>();
-    target.addIllegalOp<bufferization::AllocTensorOp>();
     // Translate strategy flags to strategy options.
     SparseTensorConversionOptions options(
         sparseToSparseConversionStrategy(sparseToSparse));

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 2a85d012b98e3..d9b3ed1c2a3bc 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -572,3 +572,14 @@ func.func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr
   sparse_tensor.out %arg0, %arg1 : tensor<?x?x?xf32, #SparseTensor>, !llvm.ptr<i8>
   return
 }
+
+// CHECK-LABEL: func @sparse_and_dense_init(
+//       CHECK: %[[S:.*]] = call @newSparseTensor
+//       CHECK: %[[D:.*]] = bufferization.alloc_tensor
+//       CHECK: return %[[S]], %[[D]] : !llvm.ptr<i8>, tensor<?x?xf64>
+func.func @sparse_and_dense_init(%arg0: index, %arg1: index)
+           -> (tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>) {
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix>
+  %1 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64>
+  return %0, %1 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>
+}


        


More information about the Mlir-commits mailing list