[Mlir-commits] [mlir] 509974a - [mlir][sparse] add folder to sparse_tensor.storage.get operation.

Peiming Liu llvmlistbot at llvm.org
Thu Dec 15 15:51:44 PST 2022


Author: Peiming Liu
Date: 2022-12-15T23:51:38Z
New Revision: 509974af0260bb8540f4a384c84a3c956a5c700c

URL: https://github.com/llvm/llvm-project/commit/509974af0260bb8540f4a384c84a3c956a5c700c
DIFF: https://github.com/llvm/llvm-project/commit/509974af0260bb8540f4a384c84a3c956a5c700c.diff

LOG: [mlir][sparse] add folder to sparse_tensor.storage_specifier.get operation.

Reviewed By: aartbik, wrengr

Differential Revision: https://reviews.llvm.org/D140172

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/fold.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 30f864fd20ff..13c6e033c13a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -214,6 +214,7 @@ def SparseTensor_GetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.get"
   let assemblyFormat = "$specifier $specifierKind (`at` $dim^)? attr-dict `:` "
                        "qualified(type($specifier)) `to` type($result)";
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set",

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 9da4ac8f92a1..628b391bdaaa 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -493,6 +493,20 @@ LogicalResult GetStorageSpecifierOp::verify() {
   return success();
 }
 
+template <typename OpTy>
+static SetStorageSpecifierOp getSpecifierSetDef(OpTy op) {
+  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
+}
+
+/// Folds to the value of the closest `set` in the def-chain with same kind/dim.
+OpFoldResult GetStorageSpecifierOp::fold(ArrayRef<Attribute> operands) {
+  auto setOp = getSpecifierSetDef(*this);
+  while (setOp && (setOp.getSpecifierKind() != getSpecifierKind() ||
+                   setOp.getDim() != getDim()))
+    setOp = getSpecifierSetDef(setOp);
+  return setOp ? OpFoldResult(setOp.getValue()) : OpFoldResult();
+}
+
 LogicalResult SetStorageSpecifierOp::verify() {
   if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(),
                                           getSpecifier(), getOperation()))) {

diff  --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index 6b1ebb173e24..7397c0b22958 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -32,6 +32,7 @@ func.func @sparse_dce_getters(%arg0: tensor<64xf32, #SparseVector>) {
   %2 = sparse_tensor.values %arg0 : tensor<64xf32, #SparseVector> to memref<?xf32>
   return
 }
+
 // CHECK-LABEL: func @sparse_concat_dce(
 //   CHECK-NOT: sparse_tensor.concatenate
 //       CHECK: return
@@ -45,3 +46,19 @@ func.func @sparse_concat_dce(%arg0: tensor<2xf64, #SparseVector>,
   return
 }
 
+// CHECK-LABEL: func @sparse_get_specifier_dce_fold(
+//  CHECK-SAME:  %[[SPEC:.*]]: !sparse_tensor.storage_specifier
+//  CHECK-SAME:  %[[DIM:.*]]: i64,
+//  CHECK-SAME:  %[[PTR:.*]]: i64)
+//   CHECK-NOT:  sparse_tensor.storage_specifier.set
+//   CHECK-NOT:  sparse_tensor.storage_specifier.get
+//       CHECK:  return %[[DIM]]
+func.func @sparse_get_specifier_dce_fold(%spec: !sparse_tensor.storage_specifier<#SparseVector>, %dim: i64, %ptr: i64) -> i64 {
+  %with_dim = sparse_tensor.storage_specifier.set %spec dim_sz at 0 with %dim
+       : i64, !sparse_tensor.storage_specifier<#SparseVector>
+  %with_ptr = sparse_tensor.storage_specifier.set %with_dim ptr_mem_sz at 0 with %ptr
+       : i64, !sparse_tensor.storage_specifier<#SparseVector>
+  %got = sparse_tensor.storage_specifier.get %with_ptr dim_sz at 0
+       : !sparse_tensor.storage_specifier<#SparseVector> to i64
+  return %got : i64
+}

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 07f6e85a067b..48b4509d6ac1 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -133,7 +133,7 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>)
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 // CHECK-LABEL: func @sparse_set_md(
-//  CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>, 
+//  CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>,
 //  CHECK-SAME: %[[I:.*]]: i64)
 //       CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.set %[[A]] dim_sz at 0 with %[[I]]
 //       CHECK: return %[[T]] : !sparse_tensor.storage_specifier<#{{.*}}>
@@ -553,4 +553,3 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: mem
   sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2 { nx=2 : index, ny=1 : index}: memref<?xi64> jointly memref<?xf32>
   return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
 }
-


        


More information about the Mlir-commits mailing list