[Mlir-commits] [mlir] d50d678 - [mlir][sparse] Add lowering rules for sparse_tensor.storage Op
Peiming Liu
llvmlistbot at llvm.org
Tue Sep 6 14:04:24 PDT 2022
Author: Peiming Liu
Date: 2022-09-06T21:04:16Z
New Revision: d50d67885452eea43dc45f3950a7f59a232bc35d
URL: https://github.com/llvm/llvm-project/commit/d50d67885452eea43dc45f3950a7f59a232bc35d
DIFF: https://github.com/llvm/llvm-project/commit/d50d67885452eea43dc45f3950a7f59a232bc35d.diff
LOG: [mlir][sparse] Add lowering rules for sparse_tensor.storage Op
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D133368
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
index 31370e32cb063..1f7afa1d77804 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
@@ -65,6 +65,22 @@ static void flattenOperands(ValueRange operands,
// Conversion rules.
//===----------------------------------------------------------------------===//
+/// Conversion rule for sparse_tensor::storage: replaces the op with a cast.
+class SparseStorageConversion : public OpConversionPattern<StorageOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(StorageOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Simply convert it to an unrealized_conversion_cast of the inputs.
+ // All uses of the sparse_tensor.storage op are expected to eventually be
+ // eliminated by accessing the flattened SSA values directly.
+ rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+ op, TypeRange{op.getType()}, adaptor.getInputs());
+ return success();
+ }
+};
+
/// Sparse tensor storage conversion rule for sparse_tensor::storage_get.
class SparseStorageGetConverter : public OpConversionPattern<StorageGetOp> {
public:
@@ -195,7 +211,8 @@ mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
/// to expand compounded sparse tensor tuples.
void mlir::populateSparseTensorStorageExpansionPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
- patterns.add<SparseStorageGetConverter, SparseStorageSetConverter,
- SparseStorageReturnConverter, SparseStorageCallConverter>(
- typeConverter, patterns.getContext());
+ patterns.add<SparseStorageConversion, SparseStorageGetConverter,
+ SparseStorageSetConverter, SparseStorageReturnConverter,
+ SparseStorageCallConverter>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index a5500a8d351c2..8c7968022e6f6 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,6 +1,5 @@
-// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s --check-prefix=CHECK-CODEGEN
-// FIXME:
-// R_U_N: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefix=CHECK-STORAGE
+// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s --check-prefixes=CHECK,CHECK-CODEGEN
+// RUN: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefixes=CHECK,CHECK-STORAGE
#SparseVector = #sparse_tensor.encoding<{
dimLevelType = [ "compressed" ],
@@ -263,43 +262,49 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
return
}
-// CHECK-CODEGEN-LABEL: func @sparse_alloc_csc(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: index)
-// CHECK-CODEGEN-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-CODEGEN-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-CODEGEN-DAG: %[[C10:.*]] = arith.constant 10 : index
-// CHECK-CODEGEN: %[[T0:.*]] = memref.alloc() : memref<2xindex>
-// CHECK-CODEGEN: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex>
-// CHECK-CODEGEN: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex>
-// CHECK-CODEGEN: %[[T1:.*]] = memref.alloc() : memref<1xindex>
-// CHECK-CODEGEN: %[[T2:.*]] = memref.cast %[[T1]] : memref<1xindex> to memref<?xindex>
-// CHECK-CODEGEN: %[[T3:.*]] = memref.alloc() : memref<1xindex>
-// CHECK-CODEGEN: %[[T4:.*]] = memref.cast %[[T3]] : memref<1xindex> to memref<?xindex>
-// CHECK-CODEGEN: %[[T5:.*]] = memref.alloc() : memref<1xf64>
-// CHECK-CODEGEN: %[[T6:.*]] = memref.cast %[[T5]] : memref<1xf64> to memref<?xf64>
-// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[T0]], %[[T2]], %[[T4]], %[[T6]])
-// CHECK-CODEGEN: return %[[T]] : tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>>
+// CHECK-LABEL: func @sparse_alloc_csc(
+// CHECK-SAME: %[[A:.*]]: index) ->
+// CHECK-CODEGEN-SAME: tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>>
+// CHECK-STORAGE-SAME: memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK: %[[T0:.*]] = memref.alloc() : memref<2xindex>
+// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex>
+// CHECK: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex>
+// CHECK: %[[T1:.*]] = memref.alloc() : memref<1xindex>
+// CHECK: %[[T2:.*]] = memref.cast %[[T1]] : memref<1xindex> to memref<?xindex>
+// CHECK: %[[T3:.*]] = memref.alloc() : memref<1xindex>
+// CHECK: %[[T4:.*]] = memref.cast %[[T3]] : memref<1xindex> to memref<?xindex>
+// CHECK: %[[T5:.*]] = memref.alloc() : memref<1xf64>
+// CHECK: %[[T6:.*]] = memref.cast %[[T5]] : memref<1xf64> to memref<?xf64>
+// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[T0]], %[[T2]], %[[T4]], %[[T6]])
+// CHECK-CODEGEN: return %[[T]]
+// CHECK-STORAGE: return %[[T0]], %[[T2]], %[[T4]], %[[T6]]
func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
%0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
%1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC>
return %1 : tensor<10x?xf64, #CSC>
}
-// CHECK-CODEGEN-LABEL: func @sparse_alloc_3d() -> tuple<memref<3xindex>, memref<?xf64>>
-// CHECK-CODEGEN-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-CODEGEN-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-CODEGEN-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-CODEGEN-DAG: %[[C10:.*]] = arith.constant 10 : index
-// CHECK-CODEGEN-DAG: %[[C20:.*]] = arith.constant 20 : index
-// CHECK-CODEGEN-DAG: %[[C30:.*]] = arith.constant 30 : index
-// CHECK-CODEGEN: %[[A0:.*]] = memref.alloc() : memref<3xindex>
-// CHECK-CODEGEN: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex>
-// CHECK-CODEGEN: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex>
-// CHECK-CODEGEN: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
-// CHECK-CODEGEN: %[[A:.*]] = memref.alloc() : memref<6000xf64>
-// CHECK-CODEGEN: %[[A1:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
-// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]])
-// CHECK-CODEGEN: return %[[T]] : tuple<memref<3xindex>, memref<?xf64>>
+// CHECK-LABEL: func @sparse_alloc_3d() ->
+// CHECK-CODEGEN-SAME: tuple<memref<3xindex>, memref<?xf64>>
+// CHECK-STORAGE-SAME: memref<3xindex>, memref<?xf64>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
+// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index
+// CHECK: %[[A0:.*]] = memref.alloc() : memref<3xindex>
+// CHECK: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex>
+// CHECK: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex>
+// CHECK: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
+// CHECK: %[[A:.*]] = memref.alloc() : memref<6000xf64>
+// CHECK: %[[A1:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
+// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]])
+// CHECK-CODEGEN: return %[[T]] : tuple<memref<3xindex>, memref<?xf64>>
+// CHECK-STORAGE: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf64>
func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
%0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
%1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
index 87391978b0674..d2d4769353a3c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
@@ -13,8 +13,8 @@ func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>
// CHECK-LABEL: func @call_sparse_storage_expand(
// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
-// CHECK: %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]])
+// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
+// CHECK: %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]])
// CHECK: return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2 : memref<?xf64>, memref<?xf64>, f64
func.func @call_sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
-> tuple<memref<?xf64>, memref<?xf64>, f64> {
@@ -23,10 +23,21 @@ func.func @call_sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>,
return %1 : tuple<memref<?xf64>, memref<?xf64>, f64>
}
+// CHECK-LABEL: func @sparse_storage(
+// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME: %[[TMP_arg2:.*2]]: memref<?xf64>)
+// CHECK: return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]
+func.func @sparse_storage(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: memref<?xf64>)
+ -> tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>> {
+ %1 = sparse_tensor.storage(%arg0, %arg1, %arg2) : memref<?xf64>, memref<?xf64>, memref<?xf64> to tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
+ return %1 : tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
+}
+
// CHECK-LABEL: func @sparse_storage_get(
// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
+// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
// CHECK: return %[[TMP_arg0]] : memref<?xf64>
func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
%0 = sparse_tensor.storage_get %arg0[0]
@@ -38,7 +49,7 @@ func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -
// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
// CHECK-SAME: %[[TMP_arg2:.*]]: f64,
-// CHECK-SAME: %[[TMP_arg3:.*]]: memref<?xf64>)
+// CHECK-SAME: %[[TMP_arg3:.*]]: memref<?xf64>)
// CHECK: return %[[TMP_arg3]], %[[TMP_arg1]], %[[TMP_arg2]] : memref<?xf64>, memref<?xf64>, f64
func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>,
%arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
More information about the Mlir-commits
mailing list