[Mlir-commits] [mlir] 9a018a7 - [mlir][sparse] relax constraints on tensor.cast with pre-rewriting

Aart Bik llvmlistbot at llvm.org
Mon May 1 16:03:58 PDT 2023


Author: Aart Bik
Date: 2023-05-01T16:03:44-07:00
New Revision: 9a018a7b48f0f46bfa3a16b4d2579ad85409fb4c

URL: https://github.com/llvm/llvm-project/commit/9a018a7b48f0f46bfa3a16b4d2579ad85409fb4c
DIFF: https://github.com/llvm/llvm-project/commit/9a018a7b48f0f46bfa3a16b4d2579ad85409fb4c.diff

LOG: [mlir][sparse] relax constraints on tensor.cast with pre-rewriting

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D149489

Added: 
    mlir/test/Dialect/SparseTensor/post_rewriting.mlir
    mlir/test/Dialect/SparseTensor/pre_rewriting.mlir

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    mlir/test/Dialect/SparseTensor/rewriting.mlir


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index b6e6024719728..cbc93849c0a85 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -142,6 +142,9 @@ FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
 LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
                                       SmallVector<Value> &result);
 
+/// Tests if types are the same when ignoring encoding on ranked tensors.
+bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
+
 /// Function to control the folding of constant and extract slice.
 using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
 
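For reference, a minimal sketch (not part of the patch) of the now-public helper's semantics: two ranked tensor types that differ only in their encoding attribute compare as the same, while any shape or element-type mismatch does not. The `enc` value below stands in for an arbitrary encoding attribute and is assumed purely for illustration:

    // Sketch only; `enc` is any encoding attribute (hypothetical placeholder).
    MLIRContext ctx;
    Builder b(&ctx);
    auto f32 = b.getF32Type();
    auto plain = RankedTensorType::get({4}, f32);
    auto encoded = RankedTensorType::get({4}, f32, /*encoding=*/enc);
    assert(tensor::isSameTypeWithoutEncoding(plain, encoded));  // encoding ignored
    assert(!tensor::isSameTypeWithoutEncoding(
        plain, RankedTensorType::get({8}, f32)));               // shapes differ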

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 42776c7d80a32..4b98030a9dbfa 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -449,8 +449,6 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
   return nullptr;
 }
 
-/// Returns true iff the given sparse tensor encoding attribute has a trailing
-/// COO region starting at the given level.
 bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                     Level startLvl, bool isUnique) {
   if (!enc ||

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index d22ea43fbf5ac..73cbe611aa376 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -346,6 +346,45 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
   }
 };
 
+// Fuse a tensor cast into its producing operation. Note that a tensor.cast
+// should really not be used to convert between sparse encodings. Since
+// the pattern currently appears as a result of some prior rewriting,
+// we make an attempt to repair very obvious cases.
+// TODO: audit the pure tensor dialect rewriting rules
+struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
+public:
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    Type srcType = op.getSource().getType();
+    Type dstType = op.getDest().getType();
+    // A nop cast simply folds away.
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, op->getResults());
+      return success();
+    }
+    // See if a sparsity-changing cast can be fused into its producer.
+    if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
+      if (Operation *def = op.getSource().getDefiningOp()) {
+        if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
+          def->getResult(0).setType(op->getResultTypes()[0]);
+          rewriter.replaceOp(op, def->getResult(0));
+          return success();
+        }
+      }
+    }
+    // Repair tensor casts with at least one sparse operand into the
+    // properly supported sparse_tensor.convert.
+    if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
+      rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
+      return success();
+    }
+    // Fail otherwise.
+    return failure();
+  }
+};
+
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
 template <typename ReshapeOp>
 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
@@ -1125,7 +1164,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd>(
+  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast>(
       patterns.getContext());
 }
 
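As a usage sketch (the driver code below is assumed, not part of the patch), the populated set, which now includes FuseTensorCast, is applied with the standard greedy rewrite driver, e.g. from inside a pass:

    // Minimal sketch: run the pre-sparsification patterns on the current op.
    RewritePatternSet patterns(&getContext());
    mlir::populatePreSparsificationRewriting(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();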

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 33639e1913ecc..8e2461f499b2a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -110,6 +110,16 @@ LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
   return success();
 }
 
+bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
+  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
+    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
+      return rtp1.getShape() == rtp2.getShape() &&
+             rtp1.getElementType() == rtp2.getElementType();
+    return false;
+  }
+  return tp1 == tp2; // default implementation
+}
+
 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
 /// rank-extending tensor.insert_slice op.
 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
@@ -1343,18 +1353,6 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                       getReassociationIndicesAttribute(b, reassociation));
 }
 
-// Checks if types are the same, but ignoring encoding on ranked tensors.
-static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
-  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
-    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
-      return rtp1.getShape() == rtp2.getShape() &&
-             rtp1.getElementType() == rtp2.getElementType();
-    return false;
-  }
-  // Default implementation.
-  return tp1 == tp2;
-}
-
 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                         TensorReshapeOp, ExpandShapeOp>::value>
 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
@@ -1367,7 +1365,7 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
   auto maps = op.getReassociationMaps();
   RankedTensorType expectedType =
       CollapseShapeOp::inferCollapsedType(expandedType, maps);
-  if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
+  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
     return op.emitOpError("expected collapsed type to be ")
            << expectedType << ", but got " << collapsedType;
   return success();

diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/post_rewriting.mlir
old mode 100755
new mode 100644
similarity index 100%
rename from mlir/test/Dialect/SparseTensor/rewriting.mlir
rename to mlir/test/Dialect/SparseTensor/post_rewriting.mlir

diff --git a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
new file mode 100644
index 0000000000000..bbe1d6a10ee76
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -pre-sparsification-rewrite | FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+#SortedCOO = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ]
+}>
+
+#Slice = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ],
+  slice = [ (?, 1, 1), (?, 3, 1) ]
+}>
+
+// CHECK-LABEL: func @sparse_nop_cast(
+//  CHECK-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>)
+//       CHECK: return %[[A]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_nop_cast(%a : tensor<?xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %a : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  %1 = tensor.cast %0 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  %2 = tensor.cast %1 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  return %2 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_repair_cast(
+//  CHECK-SAME: %[[A:.*]]: tensor<?xf32>)
+//       CHECK: %[[C:.*]] = sparse_tensor.convert %[[A]] : tensor<?xf32> to tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>
+//       CHECK: return %[[C]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_repair_cast(%a : tensor<?xf32>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %a : tensor<?xf32> to tensor<?xf32, #SparseVector>
+  return %0 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_fuse_slice(
+//  CHECK-SAME: %[[A:.*]]: tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>>)
+//       CHECK: %[[E:.*]] = tensor.extract_slice %[[A]][1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+//       CHECK: %[[C:.*]] = sparse_tensor.convert %[[E]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+//       CHECK: return %[[C]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64, #SortedCOO> {
+  %extracted_slice = tensor.extract_slice %a[1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #SortedCOO> to tensor<1x3xi64>
+  %cast = tensor.cast %extracted_slice : tensor<1x3xi64> to tensor<1x3xi64, #Slice>
+  %0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
+  return %0 : tensor<1x3xi64, #SortedCOO>
+}
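
One further case the pattern repairs, sketched here under an assumed second encoding #OtherOrder (hypothetical; not part of the committed tests): a cast between two different sparse encodings is also rewritten into sparse_tensor.convert, since the types agree modulo encoding but the source is not produced by a tensor.extract_slice:

    // Hypothetical example; #OtherOrder is any encoding distinct from #SparseVector.
    func.func @sparse_reencode_cast(%a : tensor<?xf32, #SparseVector>) -> tensor<?xf32, #OtherOrder> {
      %0 = tensor.cast %a : tensor<?xf32, #SparseVector> to tensor<?xf32, #OtherOrder>
      // after -pre-sparsification-rewrite:
      //   %0 = sparse_tensor.convert %a : tensor<?xf32, #SparseVector> to tensor<?xf32, #OtherOrder>
      return %0 : tensor<?xf32, #OtherOrder>
    }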

More information about the Mlir-commits mailing list