[Mlir-commits] [mlir] 9a018a7 - [mlir][sparse] relax constraints on tensor.cast with pre-rewriting
Aart Bik
llvmlistbot at llvm.org
Mon May 1 16:03:58 PDT 2023
Author: Aart Bik
Date: 2023-05-01T16:03:44-07:00
New Revision: 9a018a7b48f0f46bfa3a16b4d2579ad85409fb4c
URL: https://github.com/llvm/llvm-project/commit/9a018a7b48f0f46bfa3a16b4d2579ad85409fb4c
DIFF: https://github.com/llvm/llvm-project/commit/9a018a7b48f0f46bfa3a16b4d2579ad85409fb4c.diff
LOG: [mlir][sparse] relax constraints on tensor.cast with pre-rewriting
Reviewed By: wrengr
Differential Revision: https://reviews.llvm.org/D149489
Added:
mlir/test/Dialect/SparseTensor/post_rewriting.mlir
mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
Modified:
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Removed:
mlir/test/Dialect/SparseTensor/rewriting.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index b6e6024719728..cbc93849c0a85 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -142,6 +142,9 @@ FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
SmallVector<Value> &result);
+/// Tests if types are the same when ignoring encoding on ranked tensors.
+/// When `tp1` is a ranked tensor type, the result is true iff `tp2` is also
+/// a ranked tensor type with the same shape and element type (the encoding
+/// attribute is not compared). Otherwise, the two types are compared for
+/// exact equality.
+bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
+
/// Function to control the folding of constant and extract slice.
using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 42776c7d80a32..4b98030a9dbfa 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -449,8 +449,6 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
return nullptr;
}
-/// Returns true iff the given sparse tensor encoding attribute has a trailing
-/// COO region starting at the given level.
bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
Level startLvl, bool isUnique) {
if (!enc ||
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index d22ea43fbf5ac..73cbe611aa376 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -346,6 +346,45 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
}
};
+// Fuse a tensor cast into producing operation. Note that a tensor.cast
+// should really not be used to convert between sparse encodings. Since
+// the pattern currently appears as a result of some prior rewriting
+// we make an attempt to repair very obvious cases:
+//   (1) a nop cast (identical source and destination types) folds away
+//       into its source operand;
+//   (2) a cast that only changes the encoding of the single-use result of
+//       a tensor.extract_slice is fused into that producer by retyping the
+//       producer's result;
+//   (3) any other cast with at least one sparse operand is rewritten into
+//       the properly supported sparse_tensor.convert.
+// TODO: audit the pure tensor dialect rewriting rules
+struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
+public:
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    Value src = op.getSource();
+    Type srcType = src.getType();
+    Type dstType = op.getDest().getType();
+    // A nop cast simply folds away: all users of the cast are redirected to
+    // the cast's *source*. (Replacing the op with its own results would leave
+    // those users referring to an erased operation.)
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, src);
+      return success();
+    }
+    // See if a sparsity changing cast can be fused into producer.
+    if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
+      if (Operation *def = src.getDefiningOp()) {
+        if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
+          // The producer is modified in place, so the change must go through
+          // the rewriter to keep the pattern driver's bookkeeping consistent.
+          rewriter.updateRootInPlace(
+              def, [&]() { def->getResult(0).setType(dstType); });
+          rewriter.replaceOp(op, def->getResult(0));
+          return success();
+        }
+      }
+    }
+    // Repair tensor casts with at least one sparse operand into the
+    // properly supported sparse_tensor.convert.
+    // NOTE(review): tensor.cast also permits shape refinement (e.g. ? -> 10);
+    // confirm sparse_tensor.convert's verifier accepts every such type pair.
+    if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
+      rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, src);
+      return success();
+    }
+    // Fail otherwise.
+    return failure();
+  }
+};
+
/// Sparse rewriting rule for sparse-to-sparse reshape operator.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
@@ -1125,7 +1164,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
//===---------------------------------------------------------------------===//
+/// Registers the pre-sparsification rewriting patterns (FoldInvariantYield,
+/// FuseSparseMultiplyOverAdd, and FuseTensorCast) into `patterns`.
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
- patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd>(
+ patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast>(
patterns.getContext());
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 33639e1913ecc..8e2461f499b2a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -110,6 +110,16 @@ LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
return success();
}
+/// Structural type comparison that disregards tensor encodings. When `tp1`
+/// is a ranked tensor, `tp2` must be a ranked tensor with an equal element
+/// type and shape; any other `tp1` falls back to plain type equality.
+bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
+  auto lhsRanked = tp1.dyn_cast<RankedTensorType>();
+  if (!lhsRanked)
+    return tp1 == tp2; // default implementation
+  auto rhsRanked = tp2.dyn_cast<RankedTensorType>();
+  if (!rhsRanked)
+    return false;
+  return lhsRanked.getElementType() == rhsRanked.getElementType() &&
+         lhsRanked.getShape() == rhsRanked.getShape();
+}
+
/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
/// rank-extending tensor.insert_slice op.
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
@@ -1343,18 +1353,6 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
getReassociationIndicesAttribute(b, reassociation));
}
-// Checks if types are the same, but ignoring encoding on ranked tensors.
-static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
- if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
- if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
- return rtp1.getShape() == rtp2.getShape() &&
- rtp1.getElementType() == rtp2.getElementType();
- return false;
- }
- // Default implementation.
- return tp1 == tp2;
-}
-
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
@@ -1367,7 +1365,7 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
auto maps = op.getReassociationMaps();
RankedTensorType expectedType =
CollapseShapeOp::inferCollapsedType(expandedType, maps);
- if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
+ if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
return op.emitOpError("expected collapsed type to be ")
<< expectedType << ", but got " << collapsedType;
return success();
diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/post_rewriting.mlir
old mode 100755
new mode 100644
similarity index 100%
rename from mlir/test/Dialect/SparseTensor/rewriting.mlir
rename to mlir/test/Dialect/SparseTensor/post_rewriting.mlir
diff --git a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
new file mode 100644
index 0000000000000..bbe1d6a10ee76
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -pre-sparsification-rewrite | FileCheck %s
+
+// Sparse encodings used by the tests below.
+#SparseVector = #sparse_tensor.encoding<{
+ dimLevelType = ["compressed"]
+}>
+
+#SortedCOO = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed-nu", "singleton" ]
+}>
+
+// Same level types as #SortedCOO plus a slice annotation; presumably each
+// triple is (offset, size, stride) per level, with dynamic offsets -- confirm
+// against the sparse_tensor encoding documentation.
+#Slice = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed-nu", "singleton" ],
+ slice = [ (?, 1, 1), (?, 3, 1) ]
+}>
+
+// A chain of casts whose source and destination types are identical is
+// expected to disappear entirely, returning the function argument directly.
+// CHECK-LABEL: func @sparse_nop_cast(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK: return %[[A]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_nop_cast(%a : tensor<?xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+ %0 = tensor.cast %a : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+ %1 = tensor.cast %0 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+ %2 = tensor.cast %1 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+ return %2 : tensor<?xf32, #SparseVector>
+}
+
+// A tensor.cast from a dense to a sparse tensor type is expected to be
+// repaired into a sparse_tensor.convert of the argument.
+// CHECK-LABEL: func @sparse_repair_cast(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>)
+// CHECK: %[[C:.*]] = sparse_tensor.convert %[[A]] : tensor<?xf32> to tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>
+// CHECK: return %[[C]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_repair_cast(%a : tensor<?xf32>) -> tensor<?xf32, #SparseVector> {
+ %0 = tensor.cast %a : tensor<?xf32> to tensor<?xf32, #SparseVector>
+ return %0 : tensor<?xf32, #SparseVector>
+}
+
+// A cast that only changes the encoding of a single-use tensor.extract_slice
+// result is expected to be fused into the slice: the extract_slice result takes
+// on the cast's type, and the subsequent sparse_tensor.convert consumes the
+// slice directly.
+// CHECK-LABEL: func @sparse_fuse_slice(
+// CHECK-SAME: %[[A:.*]]: tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK: %[[E:.*]] = tensor.extract_slice %[[A]][1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[C:.*]] = sparse_tensor.convert %[[E]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[C]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64, #SortedCOO> {
+ %extracted_slice = tensor.extract_slice %a[1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #SortedCOO> to tensor<1x3xi64>
+ %cast = tensor.cast %extracted_slice : tensor<1x3xi64> to tensor<1x3xi64, #Slice>
+ %0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
+ return %0 : tensor<1x3xi64, #SortedCOO>
+}
More information about the Mlir-commits
mailing list