[Mlir-commits] [mlir] 1b15160 - [mlir][sparse] lower trivial tensor.cast on identical sparse tensors

Aart Bik llvmlistbot at llvm.org
Mon Oct 25 10:30:27 PDT 2021


Author: Aart Bik
Date: 2021-10-25T10:30:19-07:00
New Revision: 1b15160ef3b3edaf6fdcf403e5aa9e11ba217ce1

URL: https://github.com/llvm/llvm-project/commit/1b15160ef3b3edaf6fdcf403e5aa9e11ba217ce1
DIFF: https://github.com/llvm/llvm-project/commit/1b15160ef3b3edaf6fdcf403e5aa9e11ba217ce1.diff

LOG: [mlir][sparse] lower trivial tensor.cast on identical sparse tensors

Even though tensor.cast is not part of the sparse tensor dialect,
it may be used to cast static dimension sizes to dynamic dimension
sizes for sparse tensors without changing the actual sparse tensor
itself. Such casts must still be lowered correctly when sparse tensor
types are replaced with their opaque pointers. Likewise, this
revision handles no-op sparse conversions in a similar manner.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D112173

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 4b122aec293d..8a2daa9ecb9d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -418,6 +418,22 @@ class SparseTensorToDimSizeConverter
   }
 };
 
+/// Sparse conversion rule for tensor casts between identical sparse encodings.
+class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only rewrite when source and dest carry the same sparse encoding.
+    auto encDst = getSparseTensorEncoding(op.getType());
+    auto encSrc = getSparseTensorEncoding(op.source().getType());
+    if (!encDst || encDst != encSrc)
+      return failure(); // dense dest or encoding change: not a trivial cast
+    rewriter.replaceOp(op, adaptor.getOperands()); // reuse converted operand
+    return success();
+  }
+};
+
 /// Sparse conversion rule for the new operator.
 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -478,6 +494,10 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       // Using the coordinate scheme as an intermediate does not always
       // yield the fastest conversion but avoids the need for a full
       // O(N^2) conversion matrix.
+      if (encDst == encSrc) {
+        rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
+        return success();
+      }
       SmallVector<Value, 4> sizes;
       SmallVector<Value, 8> params;
       sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
@@ -719,9 +739,10 @@ class SparseTensorToTensorConverter : public OpConversionPattern<ToTensorOp> {
 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                   RewritePatternSet &patterns) {
   patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
-               SparseTensorNewConverter, SparseTensorInitConverter,
-               SparseTensorConvertConverter, SparseTensorReleaseConverter,
-               SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
-               SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
-      typeConverter, patterns.getContext());
+               SparseCastConverter, SparseTensorNewConverter, // tensor.cast nops
+               SparseTensorInitConverter, SparseTensorConvertConverter,
+               SparseTensorReleaseConverter, SparseTensorToPointersConverter,
+               SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
+               SparseTensorToTensorConverter>(typeConverter,
+                                              patterns.getContext());
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 9875a5c58ba7..69657445336b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -97,11 +97,11 @@ struct SparseTensorConversionPass
     RewritePatternSet patterns(ctx);
     SparseTensorTypeConverter converter;
     ConversionTarget target(*ctx);
-    target.addIllegalOp<ConvertOp, NewOp, ToIndicesOp, ToPointersOp, ToTensorOp,
-                        ToValuesOp>();
-    // All dynamic rules below accept new function, call, return, and dimop
-    // operations as legal output of the rewriting provided that all sparse
-    // tensor types have been fully rewritten.
+    // Everything in the sparse dialect must go!
+    target.addIllegalDialect<SparseTensorDialect>();
+    // All dynamic rules below accept new function, call, return, and tensor
+    // dim and cast operations as legal output of the rewriting provided that
+    // all sparse tensor types have been fully rewritten.
     target.addDynamicallyLegalOp<FuncOp>(
         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
@@ -112,10 +112,13 @@ struct SparseTensorConversionPass
     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
       return converter.isLegal(op.getOperandTypes());
     });
+    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
+      return converter.isLegal(op.getOperand().getType());
+    });
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
     target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
-                      arith::IndexCastOp, tensor::CastOp, tensor::ExtractOp>();
+                      arith::IndexCastOp, tensor::ExtractOp>();
     target.addLegalDialect<LLVM::LLVMDialect, memref::MemRefDialect,
                            scf::SCFDialect>();
     // Populate with rules and apply rewriting rules.

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 41dd00b80e16..2d74f1db4e49 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -150,6 +150,22 @@ func @sparse_nop_convert(%arg0: tensor<64xf32, #SparseVector>) -> tensor<64xf32,
   return %0 : tensor<64xf32, #SparseVector>
 }
 
+// CHECK-LABEL: func @sparse_hidden_nop_cast(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//       CHECK: return %[[A]] : !llvm.ptr<i8>
+func @sparse_hidden_nop_cast(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xf32, #SparseVector> // same encoding: lowers to %arg0
+  return %0 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_nop_cast(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//       CHECK: return %[[A]] : !llvm.ptr<i8>
+func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector> // shape-only cast: lowers to %arg0
+  return %0 : tensor<?xf32, #SparseVector>
+}
+
 // CHECK-LABEL: func @sparse_convert_1d(
 //  CHECK-SAME: %[[A:.*]]: tensor<?xi32>) -> !llvm.ptr<i8>
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index


        


More information about the Mlir-commits mailing list