[Mlir-commits] [mlir] ca5d0a7 - [mlir][sparse] keep runtime support library signature consistent
Aart Bik
llvmlistbot at llvm.org
Wed May 12 10:00:03 PDT 2021
Author: Aart Bik
Date: 2021-05-12T09:59:46-07:00
New Revision: ca5d0a7310bfb21730ac6dd735e06502e7e45099
URL: https://github.com/llvm/llvm-project/commit/ca5d0a7310bfb21730ac6dd735e06502e7e45099
DIFF: https://github.com/llvm/llvm-project/commit/ca5d0a7310bfb21730ac6dd735e06502e7e45099.diff
LOG: [mlir][sparse] keep runtime support library signature consistent
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D102285
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/conversion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 336e834cc109c..68adb6fe1db18 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
MLIRSCF
MLIRStandard
MLIRSparseTensor
+ MLIRTensor
MLIRTransforms
MLIRVector
)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 71515fecb0606..a2c7b8516d41e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -103,15 +104,21 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
return failure();
// User pointer.
params.push_back(operands[0]);
- // Sparsity annotations.
+ // Sparsity annotations in tensor constant form. Note that we cast
+ // the static shape into a dynamic shape to ensure that the method
+ // signature remains uniform across different tensor dimensions.
SmallVector<bool, 4> attrs;
unsigned sz = enc.getDimLevelType().size();
for (unsigned i = 0; i < sz; i++)
attrs.push_back(enc.getDimLevelType()[i] ==
SparseTensorEncodingAttr::DimLevelType::Compressed);
- auto elts = DenseElementsAttr::get(
- RankedTensorType::get({sz}, rewriter.getIntegerType(1)), attrs);
- params.push_back(rewriter.create<ConstantOp>(loc, elts));
+ Type etp = rewriter.getIntegerType(1);
+ RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
+ RankedTensorType tt2 =
+ RankedTensorType::get({ShapedType::kDynamicSize}, etp);
+ auto elts =
+ rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, attrs));
+ params.push_back(rewriter.create<tensor::CastOp>(loc, tt2, elts));
// Secondary and primary types encoding.
unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 641ba4af4363b..05fd7537a474e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -120,6 +120,7 @@ struct SparseTensorConversionPass
target.addDynamicallyLegalOp<ReturnOp>(
[&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
target.addLegalOp<ConstantOp>();
+ target.addLegalOp<tensor::CastOp>();
populateFuncOpTypeConversionPattern(patterns, converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateSparseTensorConversionPatterns(converter, patterns);
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 54bfa745dff4b..c7496658db966 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -16,6 +16,10 @@
indexBitWidth = 32
}>
+#SparseMatrix = #sparse_tensor.encoding<{
+ dimLevelType = ["dense", "compressed"]
+}>
+
// CHECK-LABEL: func @sparse_dim(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = constant 0 : index
@@ -27,15 +31,28 @@ func @sparse_dim(%arg0: tensor<?xf64, #SparseVector>) -> index {
return %0 : index
}
-// CHECK-LABEL: func @sparse_new(
+// CHECK-LABEL: func @sparse_new1d(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]]
+// CHECK: %[[D:.*]] = constant dense<true> : tensor<1xi1>
+// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<1xi1> to tensor<?xi1>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi1>, i64, i64, i64) -> !llvm.ptr<i8>
// CHECK: return %[[T]] : !llvm.ptr<i8>
-func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
+func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
return %0 : tensor<128xf64, #SparseVector>
}
+// CHECK-LABEL: func @sparse_new2d(
+// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK: %[[D:.*]] = constant dense<[false, true]> : tensor<2xi1>
+// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<2xi1> to tensor<?xi1>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi1>, i64, i64, i64) -> !llvm.ptr<i8>
+// CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
+ %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #SparseMatrix>
+ return %0 : tensor<?x?xf32, #SparseMatrix>
+}
+
// CHECK-LABEL: func @sparse_pointers(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = constant 0 : index
More information about the Mlir-commits mailing list