[Mlir-commits] [mlir] debdf7e - [mlir][sparse] refine single condition set up for semi-ring ops
Aart Bik
llvmlistbot at llvm.org
Wed Jun 14 09:23:18 PDT 2023
Author: Aart Bik
Date: 2023-06-14T09:23:09-07:00
New Revision: debdf7e0ffe154d85c7c5aee7b2aa88a85eccc0d
URL: https://github.com/llvm/llvm-project/commit/debdf7e0ffe154d85c7c5aee7b2aa88a85eccc0d
DIFF: https://github.com/llvm/llvm-project/commit/debdf7e0ffe154d85c7c5aee7b2aa88a85eccc0d.diff
LOG: [mlir][sparse] refine single condition set up for semi-ring ops
Reviewed By: Peiming, K-Wu
Differential Revision: https://reviews.llvm.org/D152874
Added:
mlir/test/Dialect/SparseTensor/semi_ring.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index c546a7f5e1c5a..6ec5d42a78c36 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -538,9 +538,9 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
case TensorExp::Kind::kCIm:
case TensorExp::Kind::kCRe:
case TensorExp::Kind::kBitCast:
+ case TensorExp::Kind::kUnary:
return isSingleCondition(t, expr.children.e0);
case TensorExp::Kind::kBinaryBranch:
- case TensorExp::Kind::kUnary:
case TensorExp::Kind::kSelect:
return false;
// Binary operations.
@@ -559,6 +559,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
case TensorExp::Kind::kMulC:
case TensorExp::Kind::kMulI:
case TensorExp::Kind::kAndI:
+ case TensorExp::Kind::kReduce:
if (isSingleCondition(t, expr.children.e0))
return isSingleCondition(t, expr.children.e1) ||
isInvariant(expr.children.e1);
@@ -576,7 +577,6 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
case TensorExp::Kind::kOrI:
case TensorExp::Kind::kXorI:
case TensorExp::Kind::kBinary:
- case TensorExp::Kind::kReduce:
return false;
}
llvm_unreachable("unexpected kind");
@@ -783,6 +783,7 @@ void Merger::dumpExp(ExprId e) const {
llvm::dbgs() << " " << kindToOpSymbol(expr.kind) << " ";
dumpExp(expr.children.e1);
llvm::dbgs() << ")";
+ break;
}
}
@@ -917,11 +918,11 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
UnaryOp unop = cast<UnaryOp>(expr.op);
const LatSetId child0 = buildLattices(e0, i);
Region &absentRegion = unop.getAbsentRegion();
-
if (absentRegion.empty()) {
// Simple mapping over existing values.
return mapSet(kind, child0, Value(), unop);
- } // Use a disjunction with `unop` on the left and the absent value as an
+ }
+ // Use a disjunction with `unop` on the left and the absent value as an
// invariant on the right.
Block &absentBlock = absentRegion.front();
YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
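A note on the kReduce part of the change above: custom reductions are now analyzed by isSingleCondition in the same way as the conjunctive (multiplicative) kinds, rather than unconditionally returning false. Below is a minimal sketch of the kind of kernel that exercises a custom semi-ring reduction over a single sparse operand; the #SV encoding, trait, and function names are illustrative assumptions only and are not part of this commit.

#SV = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>

#trait_prod = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> ()>    // x (scalar out)
  ],
  iterator_types = ["reduction"],
  doc = "x(_) = PROD_i a(i)"
}

// Product reduction over the stored values of a sparse vector,
// expressed with a custom semi-ring reduction (identity = 1.0).
func.func @prod_reduction(%arga: tensor<32xf64, #SV>,
                          %argx: tensor<f64>) -> tensor<f64> {
  %cf1 = arith.constant 1.0 : f64
  %0 = linalg.generic #trait_prod
    ins(%arga: tensor<32xf64, #SV>)
    outs(%argx: tensor<f64>) {
      ^bb(%a: f64, %x: f64):
        %r = sparse_tensor.reduce %a, %x, %cf1 : f64 {
            ^bb0(%p: f64, %q: f64):
              %m = arith.mulf %p, %q : f64
              sparse_tensor.yield %m : f64
          }
        linalg.yield %r : f64
    } -> tensor<f64>
  return %0 : tensor<f64>
}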
diff --git a/mlir/test/Dialect/SparseTensor/semi_ring.mlir b/mlir/test/Dialect/SparseTensor/semi_ring.mlir
new file mode 100644
index 0000000000000..762ef5f678e43
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/semi_ring.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#SM = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
+
+#trait = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)> // A
+ ],
+ iterator_types = ["parallel", "parallel"],
+ doc = "A(i,j) += 2.0 where A(i,j) != 0"
+}
+
+module {
+ // Example of a semi-ring operation that only adds a
+ // constant at stored values (something that would
+ // typically not sparsify since it would densify the
+ // implicit zeros in the normal case). The sparse
+ // compiler should see that this is a "simply dynamic"
+ // operation, and the values can be changed "in-place".
+ //
+ // CHECK-LABEL: func.func @add_only_where_nonzero(
+ // CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> {
+ // CHECK-DAG: %[[VAL_1:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2.000000e+00 : f64
+ // CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+ // CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+ // CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
+ // CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+ // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : index
+ // CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_9]]] : memref<?xindex>
+ // CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_10]] step %[[VAL_3]] {
+ // CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xf64>
+ // CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_12]], %[[VAL_4]] : f64
+ // CHECK: memref.store %[[VAL_13]], %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xf64>
+ // CHECK: } {"Emitted from" = "linalg.generic"}
+ // CHECK: } {"Emitted from" = "linalg.generic"}
+ // CHECK: %[[VAL_14:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+ // CHECK: return %[[VAL_14]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+ // CHECK: }
+ func.func @add_only_where_nonzero(%argA: tensor<8x8xf64, #SM>) -> tensor<8x8xf64, #SM> {
+ %c = arith.constant 2.0 : f64
+ %result = linalg.generic #trait
+ outs(%argA: tensor<8x8xf64, #SM>) {
+ ^bb(%a: f64):
+ %u = sparse_tensor.unary %a : f64 to f64
+ present={
+ ^bb0(%p: f64):
+ %add = arith.addf %p, %c : f64
+ sparse_tensor.yield %add : f64
+ }
+ absent={}
+ linalg.yield %u : f64
+ } -> tensor<8x8xf64, #SM>
+ return %result : tensor<8x8xf64, #SM>
+ }
+}
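For local verification, the new test can be run exactly as its RUN line states, e.g. "mlir-opt mlir/test/Dialect/SparseTensor/semi_ring.mlir -sparsification | FileCheck mlir/test/Dialect/SparseTensor/semi_ring.mlir", assuming a built mlir-opt and FileCheck are on PATH and the command is issued from the llvm-project source root.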