[Mlir-commits] [mlir] [mlir][sparse] Replace `getSparseTensorType` with `tryGetSparseTensorType` (PR #109435)

Longsheng Mou llvmlistbot at llvm.org
Fri Sep 20 08:00:44 PDT 2024


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/109435

This PR fixes a crash in `SparseTensorDimOpRewriter` when the source of `tensor.dim` has an unranked tensor type: `getSparseTensorType` requires a ranked tensor type, so we now use `tryGetSparseTensorType`, which returns `std::optional<SparseTensorType>`, and bail out on unranked inputs. The same guard is applied in `TensorReshapeRewriter`. Fixes #107807.
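For reference, this is the kind of input that previously crashed the pattern (it is also added as a regression test at the end of the patch); after this change the `tensor.dim` on an unranked tensor is simply left untouched:

  func.func @test_tensor_dim_unranked(%arg0: tensor<*xf32>) -> index {
    %c = arith.constant 0 : index
    %0 = tensor.dim %arg0, %c : tensor<*xf32>
    return %0 : index
  }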

From c4b27e30e8d628b329e159122a7e14520724729d Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 20 Sep 2024 22:50:17 +0800
Subject: [PATCH] [mlir][sparse] Replace `getSparseTensorType` with
 `tryGetSparseTensorType`

This PR fixes a bug in `SparseTensorDimOpRewriter` when `tensor.dim`
has an unranked tensor type. To prevent crashes, we now use
`tryGetSparseTensorType` instead of `getSparseTensorType`.
---
 .../Transforms/SparseTensorRewriting.cpp      | 42 ++++++++++---------
 mlir/test/Dialect/SparseTensor/codegen.mlir   | 16 +++++++
 2 files changed, 38 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index cc372ed1be6217..60db71d96547fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -881,25 +881,27 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value srcTensor = op.getSource();
-    const auto srcTp = getSparseTensorType(srcTensor);
-    const auto dstTp = getSparseTensorType(op.getResult());
+    const auto srcTp = tryGetSparseTensorType(srcTensor);
+    const auto dstTp = tryGetSparseTensorType(op.getResult());
+    if (!srcTp || !dstTp)
+      return failure();
 
-    if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
-        !dstTp.hasStaticDimShape())
+    if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
+        !dstTp->hasStaticDimShape())
       return failure();
 
     SmallVector<Value> srcSizes;
-    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+    sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
     SmallVector<Value> dstSizes;
-    for (Dimension d : dstTp.getDimShape())
+    for (Dimension d : dstTp->getDimShape())
       dstSizes.push_back(constantIndex(rewriter, loc, d));
 
     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
     // Only need an unordered COO buffer if input and output are not sorted
     // in the same way.
     Type bufferTp = getBufferType(
-        dstTp.withoutDimToLvl(),
-        !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
+        dstTp->withoutDimToLvl(),
+        !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
     SmallVector<Value> dynSizes;
     Value buffer = rewriter
                        .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
@@ -917,12 +919,12 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
     // followed by an optional
     //   %t = sparse_tensor.cast %tmp
     // depending on whether the input/output are sorted in the same way.
-    const auto encSrc = srcTp.getEncoding();
+    const auto encSrc = srcTp->getEncoding();
     ForeachOp foreachOp = rewriter.create<ForeachOp>(
         loc, srcTensor, buffer,
         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
             ValueRange reduc) {
-          const Dimension srcRank = srcTp.getDimRank();
+          const Dimension srcRank = srcTp->getDimRank();
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(srcRank);
           for (Dimension d = 0; d < srcRank; d++) {
@@ -945,7 +947,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
                      collapsedSizes, collapsedDcvs);
 
           ReassociationIndices expandIdx;
-          for (Dimension i = 0; i < dstTp.getDimRank(); i++)
+          for (Dimension i = 0; i < dstTp->getDimRank(); i++)
             expandIdx.push_back(i);
           SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
           SmallVector<Value> dstDcvs;
@@ -958,8 +960,8 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
         });
 
     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
-    if (bufferTp != dstTp) {
-      auto dstRTT = dstTp.getRankedTensorType();
+    if (bufferTp != *dstTp) {
+      auto dstRTT = dstTp->getRankedTensorType();
       Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
       rewriter.create<DeallocTensorOp>(loc, t);
       t = converted;
@@ -1139,13 +1141,13 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
   LogicalResult matchAndRewrite(tensor::DimOp op,
                                 PatternRewriter &rewriter) const override {
     std::optional<int64_t> dim = op.getConstantIndex();
-    auto stt = getSparseTensorType(op.getSource());
-    if (!dim || !stt.hasEncoding())
+    auto stt = tryGetSparseTensorType(op.getSource());
+    if (!dim || !stt || !stt->hasEncoding())
       return failure();
 
-    if (stt.isPermutation()) {
+    if (stt->isPermutation()) {
       rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
-                                         toLvl(stt.getEncoding(), *dim));
+                                         toLvl(stt->getEncoding(), *dim));
       return success();
     }
 
@@ -1157,16 +1159,16 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
     // computed simply by lvl_size * block_size.
     Location loc = op.getLoc();
     SmallVector<Value> maxLvlCrds;
-    for (Level l = 0; l < stt.getLvlRank(); l++) {
+    for (Level l = 0; l < stt->getLvlRank(); l++) {
       Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
       Value maxLvlCrd = rewriter.create<arith::SubIOp>(
           loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
       maxLvlCrds.push_back(maxLvlCrd);
     }
 
-    AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
+    AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
     Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
-        op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
+        op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
         maxLvlCrds);
 
     Value dimSz = rewriter.create<arith::AddIOp>(
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index af78458f109329..df03d871ba3a3e 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -826,3 +826,19 @@ func.func @sparse_new_coo_permute_no(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CooPN
   %0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<?x?xf32, #CooPNo>
   return %0 : tensor<?x?xf32, #CooPNo>
 }
+
+// CHECK-LABEL: func.func @test_tensor_dim_unranked
+//       CHECK: tensor.dim
+func.func @test_tensor_dim_unranked(%arg0: tensor<*xf32>) -> index {
+  %c = arith.constant 0 : index
+  %0 = tensor.dim %arg0, %c : tensor<*xf32>
+  return %0 : index
+}
+
+// CHECK-LABEL: func.func @test_tensor_reshape_unranked
+//       CHECK: tensor.reshape
+func.func @test_tensor_reshape_unranked(%src: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> {
+  %dst = tensor.reshape %src(%shape)
+         : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+  return %dst : tensor<?xf32>
+}


