[Mlir-commits] [mlir] ca5d0a7 - [mlir][sparse] keep runtime support library signature consistent

Aart Bik llvmlistbot at llvm.org
Wed May 12 10:00:03 PDT 2021


Author: Aart Bik
Date: 2021-05-12T09:59:46-07:00
New Revision: ca5d0a7310bfb21730ac6dd735e06502e7e45099

URL: https://github.com/llvm/llvm-project/commit/ca5d0a7310bfb21730ac6dd735e06502e7e45099
DIFF: https://github.com/llvm/llvm-project/commit/ca5d0a7310bfb21730ac6dd735e06502e7e45099.diff

LOG: [mlir][sparse] keep runtime support library signature consistent

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D102285

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 336e834cc109c..68adb6fe1db18 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   MLIRSCF
   MLIRStandard
   MLIRSparseTensor
+  MLIRTensor
   MLIRTransforms
   MLIRVector
 )

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 71515fecb0606..a2c7b8516d41e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
@@ -103,15 +104,21 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
       return failure();
     // User pointer.
     params.push_back(operands[0]);
-    // Sparsity annotations.
+    // Sparsity annotations in tensor constant form. Note that we cast
+    // the static shape into a dynamic shape to ensure that the method
+    // signature remains uniform accross different tensor dimensions.
     SmallVector<bool, 4> attrs;
     unsigned sz = enc.getDimLevelType().size();
     for (unsigned i = 0; i < sz; i++)
       attrs.push_back(enc.getDimLevelType()[i] ==
                       SparseTensorEncodingAttr::DimLevelType::Compressed);
-    auto elts = DenseElementsAttr::get(
-        RankedTensorType::get({sz}, rewriter.getIntegerType(1)), attrs);
-    params.push_back(rewriter.create<ConstantOp>(loc, elts));
+    Type etp = rewriter.getIntegerType(1);
+    RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
+    RankedTensorType tt2 =
+        RankedTensorType::get({ShapedType::kDynamicSize}, etp);
+    auto elts =
+        rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, attrs));
+    params.push_back(rewriter.create<tensor::CastOp>(loc, tt2, elts));
     // Seconary and primary types encoding.
     unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
     unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 641ba4af4363b..05fd7537a474e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -120,6 +120,7 @@ struct SparseTensorConversionPass
     target.addDynamicallyLegalOp<ReturnOp>(
         [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
     target.addLegalOp<ConstantOp>();
+    target.addLegalOp<tensor::CastOp>();
     populateFuncOpTypeConversionPattern(patterns, converter);
     populateCallOpTypeConversionPattern(patterns, converter);
     populateSparseTensorConversionPatterns(converter, patterns);

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 54bfa745dff4b..c7496658db966 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -16,6 +16,10 @@
   indexBitWidth = 32
 }>
 
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"]
+}>
+
 // CHECK-LABEL: func @sparse_dim(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = constant 0 : index
@@ -27,15 +31,28 @@ func @sparse_dim(%arg0: tensor<?xf64, #SparseVector>) -> index {
   return %0 : index
 }
 
-// CHECK-LABEL: func @sparse_new(
+// CHECK-LABEL: func @sparse_new1d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]]
+//       CHECK: %[[D:.*]] = constant dense<true> : tensor<1xi1>
+//       CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<1xi1> to tensor<?xi1>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi1>, i64, i64, i64) -> !llvm.ptr<i8>
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
-func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
+func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
   return %0 : tensor<128xf64, #SparseVector>
 }
 
+// CHECK-LABEL: func @sparse_new2d(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//       CHECK: %[[D:.*]] = constant dense<[false, true]> : tensor<2xi1>
+//       CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<2xi1> to tensor<?xi1>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi1>, i64, i64, i64) -> !llvm.ptr<i8>
+//       CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
+  %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #SparseMatrix>
+  return %0 : tensor<?x?xf32, #SparseMatrix>
+}
+
 // CHECK-LABEL: func @sparse_pointers(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = constant 0 : index


        


More information about the Mlir-commits mailing list