[Mlir-commits] [mlir] debdf7e - [mlir][sparse] refine single condition set up for semi-ring ops

Aart Bik llvmlistbot at llvm.org
Wed Jun 14 09:23:18 PDT 2023


Author: Aart Bik
Date: 2023-06-14T09:23:09-07:00
New Revision: debdf7e0ffe154d85c7c5aee7b2aa88a85eccc0d

URL: https://github.com/llvm/llvm-project/commit/debdf7e0ffe154d85c7c5aee7b2aa88a85eccc0d
DIFF: https://github.com/llvm/llvm-project/commit/debdf7e0ffe154d85c7c5aee7b2aa88a85eccc0d.diff

LOG: [mlir][sparse] refine single condition set up for semi-ring ops

Reviewed By: Peiming, K-Wu

Differential Revision: https://reviews.llvm.org/D152874

Added: 
    mlir/test/Dialect/SparseTensor/semi_ring.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index c546a7f5e1c5a..6ec5d42a78c36 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -538,9 +538,9 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
   case TensorExp::Kind::kCIm:
   case TensorExp::Kind::kCRe:
   case TensorExp::Kind::kBitCast:
+  case TensorExp::Kind::kUnary:
     return isSingleCondition(t, expr.children.e0);
   case TensorExp::Kind::kBinaryBranch:
-  case TensorExp::Kind::kUnary:
   case TensorExp::Kind::kSelect:
     return false;
   // Binary operations.
@@ -559,6 +559,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
   case TensorExp::Kind::kMulC:
   case TensorExp::Kind::kMulI:
   case TensorExp::Kind::kAndI:
+  case TensorExp::Kind::kReduce:
     if (isSingleCondition(t, expr.children.e0))
       return isSingleCondition(t, expr.children.e1) ||
              isInvariant(expr.children.e1);
@@ -576,7 +577,6 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
   case TensorExp::Kind::kOrI:
   case TensorExp::Kind::kXorI:
   case TensorExp::Kind::kBinary:
-  case TensorExp::Kind::kReduce:
     return false;
   }
   llvm_unreachable("unexpected kind");
@@ -783,6 +783,7 @@ void Merger::dumpExp(ExprId e) const {
     llvm::dbgs() << " " << kindToOpSymbol(expr.kind) << " ";
     dumpExp(expr.children.e1);
     llvm::dbgs() << ")";
+    break;
   }
 }
 
@@ -917,11 +918,11 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
       UnaryOp unop = cast<UnaryOp>(expr.op);
       const LatSetId child0 = buildLattices(e0, i);
       Region &absentRegion = unop.getAbsentRegion();
-
       if (absentRegion.empty()) {
         // Simple mapping over existing values.
         return mapSet(kind, child0, Value(), unop);
-      } // Use a disjunction with `unop` on the left and the absent value as an
+      }
+      // Use a disjunction with `unop` on the left and the absent value as an
       // invariant on the right.
       Block &absentBlock = absentRegion.front();
       YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());

diff  --git a/mlir/test/Dialect/SparseTensor/semi_ring.mlir b/mlir/test/Dialect/SparseTensor/semi_ring.mlir
new file mode 100644
index 0000000000000..762ef5f678e43
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/semi_ring.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#SM = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)> // A
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "A(i,j) += 2.0 where A(i,j) != 0"
+}
+
+module {
+  // Example of a semi-ring operation that only adds a
+  // constant at stored values (something that would
+  // typically not sparsify since it would densify the
+  // implicit zeros in the normal case). The sparse
+  // compiler should see that this is a "simply dynamic"
+  // operation, and the values can be changed "in-place".
+  //
+  // CHECK-LABEL: func.func @add_only_where_nonzero(
+  // CHECK-SAME:    %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> {
+  // CHECK-DAG:     %[[VAL_1:.*]] = arith.constant 8 : index
+  // CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 0 : index
+  // CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 1 : index
+  // CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 2.000000e+00 : f64
+  // CHECK-DAG:     %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+  // CHECK-DAG:     %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+  // CHECK:         scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
+  // CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+  // CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : index
+  // CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_9]]] : memref<?xindex>
+  // CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_10]] step %[[VAL_3]] {
+  // CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xf64>
+  // CHECK:             %[[VAL_13:.*]] = arith.addf %[[VAL_12]], %[[VAL_4]] : f64
+  // CHECK:             memref.store %[[VAL_13]], %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xf64>
+  // CHECK:           } {"Emitted from" = "linalg.generic"}
+  // CHECK:         } {"Emitted from" = "linalg.generic"}
+  // CHECK:         %[[VAL_14:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+  // CHECK:         return %[[VAL_14]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+  // CHECK:       }
+  func.func @add_only_where_nonzero(%argA: tensor<8x8xf64, #SM>) -> tensor<8x8xf64, #SM> {
+    %c = arith.constant 2.0 : f64
+    %result = linalg.generic #trait
+      outs(%argA: tensor<8x8xf64, #SM>) {
+        ^bb(%a: f64):
+           %u = sparse_tensor.unary %a : f64 to f64
+             present={
+                ^bb0(%p: f64):
+                  %add = arith.addf %p, %c : f64
+                  sparse_tensor.yield %add : f64
+             }
+             absent={}
+           linalg.yield %u : f64
+    } -> tensor<8x8xf64, #SM>
+    return %result : tensor<8x8xf64, #SM>
+  }
+}


        


More information about the Mlir-commits mailing list