[Mlir-commits] [mlir] d50d678 - [mlir][sparse] Add lowering rules for sparse_tensor.storage Op

Peiming Liu llvmlistbot at llvm.org
Tue Sep 6 14:04:24 PDT 2022


Author: Peiming Liu
Date: 2022-09-06T21:04:16Z
New Revision: d50d67885452eea43dc45f3950a7f59a232bc35d

URL: https://github.com/llvm/llvm-project/commit/d50d67885452eea43dc45f3950a7f59a232bc35d
DIFF: https://github.com/llvm/llvm-project/commit/d50d67885452eea43dc45f3950a7f59a232bc35d.diff

LOG: [mlir][sparse] Add lowering rules for sparse_tensor.storage Op

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133368
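
With this patch, the storage expansion pass also lowers the sparse_tensor.storage
op, which packs individual field buffers into a storage tuple: the op is rewritten
into an unrealized_conversion_cast over its flattened inputs, and all of its uses
are expected to be resolved against those flattened SSA values. As an illustrative
sketch (mirroring the test added to sparse_tensor_storage.mlir below), running
--sparse-tensor-storage-expansion turns roughly

  // Before: three memrefs are packed into a storage tuple.
  func.func @sparse_storage(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: memref<?xf64>)
      -> tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>> {
    %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
        : memref<?xf64>, memref<?xf64>, memref<?xf64>
        to tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
    return %0 : tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
  }

into

  // After: the tuple is gone; the flattened fields are returned directly.
  func.func @sparse_storage(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: memref<?xf64>)
      -> (memref<?xf64>, memref<?xf64>, memref<?xf64>) {
    return %arg0, %arg1, %arg2 : memref<?xf64>, memref<?xf64>, memref<?xf64>
  }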

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
    mlir/test/Dialect/SparseTensor/codegen.mlir
    mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
index 31370e32cb063..1f7afa1d77804 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
@@ -65,6 +65,22 @@ static void flattenOperands(ValueRange operands,
 // Conversion rules.
 //===----------------------------------------------------------------------===//
 
+/// Sparse tensor storage conversion rule for sparse_tensor::storage.
+class SparseStorageConversion : public OpConversionPattern<StorageOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Simply convert it to an unrealized_conversion_cast.
+    // We must guarantee that all uses of the sparse_tensor.storage op will
+    // eventually be eliminated by accessing the flattened SSA values directly.
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TypeRange{op.getType()}, adaptor.getInputs());
+    return success();
+  }
+};
+
 /// Sparse tensor storage conversion rule for sparse_tensor::storage_get.
 class SparseStorageGetConverter : public OpConversionPattern<StorageGetOp> {
 public:
@@ -195,7 +211,8 @@ mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
 /// to expand compounded sparse tensor tuples.
 void mlir::populateSparseTensorStorageExpansionPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<SparseStorageGetConverter, SparseStorageSetConverter,
-               SparseStorageReturnConverter, SparseStorageCallConverter>(
-      typeConverter, patterns.getContext());
+  patterns.add<SparseStorageConversion, SparseStorageGetConverter,
+               SparseStorageSetConverter, SparseStorageReturnConverter,
+               SparseStorageCallConverter>(typeConverter,
+                                           patterns.getContext());
 }
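
To visualize what the new SparseStorageConversion rule produces on its own: the
sparse_tensor.storage op is replaced by an unrealized_conversion_cast from the
flattened inputs back to the original tuple type, roughly (hypothetical
intermediate IR only; the other storage-expansion patterns then access
%arg0/%arg1/%arg2 directly, so every use of the cast is eventually eliminated):

  // Transient cast produced by the pattern; its uses are resolved to the
  // flattened SSA values by the remaining conversion patterns.
  %0 = builtin.unrealized_conversion_cast %arg0, %arg1, %arg2
      : memref<?xf64>, memref<?xf64>, memref<?xf64>
      to tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>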

diff  --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index a5500a8d351c2..8c7968022e6f6 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,6 +1,5 @@
-// RUN: mlir-opt %s --sparse-tensor-codegen  --canonicalize --cse | FileCheck %s --check-prefix=CHECK-CODEGEN
-// FIXME:
-// R_U_N: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefix=CHECK-STORAGE
+// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s --check-prefixes=CHECK,CHECK-CODEGEN
+// RUN: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefixes=CHECK,CHECK-STORAGE
 
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = [ "compressed" ],
@@ -263,43 +262,49 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
   return
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_alloc_csc(
-//  CHECK-CODEGEN-SAME: %[[A:.*]]: index)
-//   CHECK-CODEGEN-DAG: %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-CODEGEN-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-CODEGEN-DAG: %[[C10:.*]] = arith.constant 10 : index
-//      CHECK-CODEGEN:  %[[T0:.*]] = memref.alloc() : memref<2xindex>
-//      CHECK-CODEGEN:  memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex>
-//      CHECK-CODEGEN:  memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex>
-//      CHECK-CODEGEN:  %[[T1:.*]] = memref.alloc() : memref<1xindex>
-//      CHECK-CODEGEN:  %[[T2:.*]] = memref.cast %[[T1]] : memref<1xindex> to memref<?xindex>
-//      CHECK-CODEGEN:  %[[T3:.*]] = memref.alloc() : memref<1xindex>
-//      CHECK-CODEGEN:  %[[T4:.*]] = memref.cast %[[T3]] : memref<1xindex> to memref<?xindex>
-//      CHECK-CODEGEN:  %[[T5:.*]] = memref.alloc() : memref<1xf64>
-//      CHECK-CODEGEN:  %[[T6:.*]] = memref.cast %[[T5]] : memref<1xf64> to memref<?xf64>
-//      CHECK-CODEGEN:  %[[T:.*]] = sparse_tensor.storage(%[[T0]], %[[T2]], %[[T4]], %[[T6]])
-//      CHECK-CODEGEN:  return %[[T]] : tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>>
+//        CHECK-LABEL: func @sparse_alloc_csc(
+//         CHECK-SAME: %[[A:.*]]: index) ->
+// CHECK-CODEGEN-SAME: tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>>
+// CHECK-STORAGE-SAME: memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+//          CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//          CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//          CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+//              CHECK: %[[T0:.*]] = memref.alloc() : memref<2xindex>
+//              CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex>
+//              CHECK: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex>
+//              CHECK: %[[T1:.*]] = memref.alloc() : memref<1xindex>
+//              CHECK: %[[T2:.*]] = memref.cast %[[T1]] : memref<1xindex> to memref<?xindex>
+//              CHECK: %[[T3:.*]] = memref.alloc() : memref<1xindex>
+//              CHECK: %[[T4:.*]] = memref.cast %[[T3]] : memref<1xindex> to memref<?xindex>
+//              CHECK: %[[T5:.*]] = memref.alloc() : memref<1xf64>
+//              CHECK: %[[T6:.*]] = memref.cast %[[T5]] : memref<1xf64> to memref<?xf64>
+//      CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[T0]], %[[T2]], %[[T4]], %[[T6]])
+//      CHECK-CODEGEN: return %[[T]]
+//      CHECK-STORAGE: return %[[T0]], %[[T2]], %[[T4]], %[[T6]]
 func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
   %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
   %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC>
   return %1 : tensor<10x?xf64, #CSC>
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_alloc_3d() -> tuple<memref<3xindex>, memref<?xf64>>
-//   CHECK-CODEGEN-DAG: %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-CODEGEN-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-CODEGEN-DAG: %[[C2:.*]] = arith.constant 2 : index
-//   CHECK-CODEGEN-DAG: %[[C10:.*]] = arith.constant 10 : index
-//   CHECK-CODEGEN-DAG: %[[C20:.*]] = arith.constant 20 : index
-//   CHECK-CODEGEN-DAG: %[[C30:.*]] = arith.constant 30 : index
-//       CHECK-CODEGEN: %[[A0:.*]] = memref.alloc() : memref<3xindex>
-//       CHECK-CODEGEN: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex>
-//       CHECK-CODEGEN: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex>
-//       CHECK-CODEGEN: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
-//       CHECK-CODEGEN: %[[A:.*]] = memref.alloc() : memref<6000xf64>
-//       CHECK-CODEGEN: %[[A1:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
-//       CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]])
-//       CHECK-CODEGEN: return %[[T]] : tuple<memref<3xindex>, memref<?xf64>>
+//        CHECK-LABEL: func @sparse_alloc_3d() ->
+// CHECK-CODEGEN-SAME: tuple<memref<3xindex>, memref<?xf64>>
+// CHECK-STORAGE-SAME: memref<3xindex>, memref<?xf64>
+//          CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//          CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//          CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//          CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+//          CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
+//          CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index
+//              CHECK: %[[A0:.*]] = memref.alloc() : memref<3xindex>
+//              CHECK: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex>
+//              CHECK: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex>
+//              CHECK: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
+//              CHECK: %[[A:.*]] = memref.alloc() : memref<6000xf64>
+//              CHECK: %[[A1:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
+//      CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]])
+//      CHECK-CODEGEN: return %[[T]] : tuple<memref<3xindex>, memref<?xf64>>
+//      CHECK-STORAGE: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf64>
 func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
   %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
   %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
index 87391978b0674..d2d4769353a3c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
@@ -13,8 +13,8 @@ func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>
 // CHECK-LABEL:  func @call_sparse_storage_expand(
 // CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
 // CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME:     %[[TMP_arg2:.*]]: f64) 
-// CHECK:          %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]) 
+// CHECK-SAME:     %[[TMP_arg2:.*]]: f64)
+// CHECK:          %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]])
 // CHECK:          return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2 : memref<?xf64>, memref<?xf64>, f64
 func.func @call_sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
                                           -> tuple<memref<?xf64>, memref<?xf64>, f64> {
@@ -23,10 +23,21 @@ func.func @call_sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>,
   return %1 : tuple<memref<?xf64>, memref<?xf64>, f64>
 }
 
+// CHECK-LABEL: func @sparse_storage(
+// CHECK-SAME:    %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME:    %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME:    %[[TMP_arg2:.*2]]: memref<?xf64>)
+// CHECK:         return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]
+func.func @sparse_storage(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: memref<?xf64>)
+                        -> tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>> {
+  %1 = sparse_tensor.storage(%arg0, %arg1, %arg2) : memref<?xf64>, memref<?xf64>, memref<?xf64> to tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
+  return %1 : tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
+}
+
 // CHECK-LABEL:  func @sparse_storage_get(
 // CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
 // CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME:     %[[TMP_arg2:.*]]: f64) 
+// CHECK-SAME:     %[[TMP_arg2:.*]]: f64)
 // CHECK:          return %[[TMP_arg0]] : memref<?xf64>
 func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
   %0 = sparse_tensor.storage_get %arg0[0]
@@ -38,7 +49,7 @@ func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -
 // CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
 // CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
 // CHECK-SAME:     %[[TMP_arg2:.*]]: f64,
-// CHECK-SAME:     %[[TMP_arg3:.*]]: memref<?xf64>) 
+// CHECK-SAME:     %[[TMP_arg3:.*]]: memref<?xf64>)
 // CHECK:          return %[[TMP_arg3]], %[[TMP_arg1]], %[[TMP_arg2]] : memref<?xf64>, memref<?xf64>, f64
 func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>,
                               %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
