[Mlir-commits] [mlir] 509974a - [mlir][sparse] add folder to sparse_tensor.storage_specifier.get operation.
Peiming Liu
llvmlistbot at llvm.org
Thu Dec 15 15:51:44 PST 2022
Author: Peiming Liu
Date: 2022-12-15T23:51:38Z
New Revision: 509974af0260bb8540f4a384c84a3c956a5c700c
URL: https://github.com/llvm/llvm-project/commit/509974af0260bb8540f4a384c84a3c956a5c700c
DIFF: https://github.com/llvm/llvm-project/commit/509974af0260bb8540f4a384c84a3c956a5c700c.diff
LOG: [mlir][sparse] add folder to sparse_tensor.storage_specifier.get operation.
Reviewed By: aartbik, wrengr
Differential Revision: https://reviews.llvm.org/D140172
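
In effect, the new fold looks through the chain of sparse_tensor.storage_specifier.set operations feeding a sparse_tensor.storage_specifier.get and, when it reaches a set with the same specifier kind and dimension, returns the value that set stored; otherwise it keeps walking the chain. A minimal sketch of the resulting simplification, mirroring the new fold.mlir test below (%spec and %v are placeholder SSA names, and the #SparseVector encoding is assumed to be declared as in that test file):

  // Before: the get reads back the value just written by the set.
  %0 = sparse_tensor.storage_specifier.set %spec dim_sz at 0 with %v
       : i64, !sparse_tensor.storage_specifier<#SparseVector>
  %1 = sparse_tensor.storage_specifier.get %0 dim_sz at 0
       : !sparse_tensor.storage_specifier<#SparseVector> to i64

  // After folding, %1 is replaced by %v directly; the set becomes dead and
  // can be cleaned up by DCE. Intervening sets with a different kind or
  // dimension (e.g. the ptr_mem_sz set in the test) do not block the fold.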
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/fold.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 30f864fd20ff..13c6e033c13a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -214,6 +214,7 @@ def SparseTensor_GetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.get"
let assemblyFormat = "$specifier $specifierKind (`at` $dim^)? attr-dict `:` "
"qualified(type($specifier)) `to` type($result)";
let hasVerifier = 1;
+ let hasFolder = 1;
}
def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 9da4ac8f92a1..628b391bdaaa 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -493,6 +493,20 @@ LogicalResult GetStorageSpecifierOp::verify() {
return success();
}
+template <typename SpecifierOp>
+static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
+ return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
+}
+
+OpFoldResult GetStorageSpecifierOp::fold(ArrayRef<Attribute> operands) {
+ StorageSpecifierKind kind = getSpecifierKind();
+ Optional<APInt> dim = getDim();
+ for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
+ if (kind == op.getSpecifierKind() && dim == op.getDim())
+ return op.getValue();
+ return {};
+}
+
LogicalResult SetStorageSpecifierOp::verify() {
if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(),
getSpecifier(), getOperation()))) {
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index 6b1ebb173e24..7397c0b22958 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -32,6 +32,7 @@ func.func @sparse_dce_getters(%arg0: tensor<64xf32, #SparseVector>) {
%2 = sparse_tensor.values %arg0 : tensor<64xf32, #SparseVector> to memref<?xf32>
return
}
+
// CHECK-LABEL: func @sparse_concat_dce(
// CHECK-NOT: sparse_tensor.concatenate
// CHECK: return
@@ -45,3 +46,19 @@ func.func @sparse_concat_dce(%arg0: tensor<2xf64, #SparseVector>,
return
}
+// CHECK-LABEL: func @sparse_get_specifier_dce_fold(
+// CHECK-SAME: %[[A0:.*]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A1:.*]]: i64,
+// CHECK-SAME: %[[A2:.*]]: i64)
+// CHECK-NOT: sparse_tensor.storage_specifier.set
+// CHECK-NOT: sparse_tensor.storage_specifier.get
+// CHECK: return %[[A1]]
+func.func @sparse_get_specifier_dce_fold(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: i64, %arg2: i64) -> i64 {
+ %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
+ : i64, !sparse_tensor.storage_specifier<#SparseVector>
+ %1 = sparse_tensor.storage_specifier.set %0 ptr_mem_sz at 0 with %arg2
+ : i64, !sparse_tensor.storage_specifier<#SparseVector>
+ %2 = sparse_tensor.storage_specifier.get %1 dim_sz at 0
+ : !sparse_tensor.storage_specifier<#SparseVector> to i64
+ return %2 : i64
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 07f6e85a067b..48b4509d6ac1 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -133,7 +133,7 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>)
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
// CHECK-LABEL: func @sparse_set_md(
-// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>,
+// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>,
// CHECK-SAME: %[[I:.*]]: i64)
// CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.set %[[A]] dim_sz at 0 with %[[I]]
// CHECK: return %[[T]] : !sparse_tensor.storage_specifier<#{{.*}}>
@@ -553,4 +553,3 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: mem
sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2 { nx=2 : index, ny=1 : index}: memref<?xi64> jointly memref<?xf32>
return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
}
-