[Mlir-commits] [mlir] 70e227a - [mlir][sparse] recognize ReLu operation during sparsification (#92016)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 13 14:02:32 PDT 2024


Author: Aart Bik
Date: 2024-05-13T14:02:29-07:00
New Revision: 70e227a404e51f9248c7ad5d79953805b2afacb4

URL: https://github.com/llvm/llvm-project/commit/70e227a404e51f9248c7ad5d79953805b2afacb4
DIFF: https://github.com/llvm/llvm-project/commit/70e227a404e51f9248c7ad5d79953805b2afacb4.diff

LOG: [mlir][sparse] recognize ReLu operation during sparsification (#92016)

This is a proof-of-concept recognition of the most basic forms of ReLU
operations, used to showcase sparsification of end-to-end PyTorch
models. In the long run, we must avoid lowering such constructs too
early (which creates the need to raise them back).

See discussion at

https://discourse.llvm.org/t/min-max-abs-relu-recognition-starter-project/78918

Added: 
    mlir/test/Dialect/SparseTensor/sparse_relu.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 7f9820df984b2..b8d278152dc05 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -144,6 +144,7 @@ enum class TensorExp::Kind {
   kExpm1C,
   kLog1pF,
   kLog1pC,
+  kRelu,
   kSinF,
   kSinC,
   kTanhF,
@@ -316,7 +317,7 @@ class Merger {
   /// lattice point on an expression E is simply copied over, but with OP E
   /// as new expression. Returns the identifier of the new set.
   LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v = Value(),
-                  Operation *op = nullptr);
+                  Operation *op = nullptr, Attribute attr = nullptr);
 
   /// Maps the binary operator to the same operation but with one of its operand
   /// set to zero, i.e. each lattice point on an expression E is simply copied

diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 308fbd965259d..0258f797143cb 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -44,6 +44,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -104,7 +105,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
 
 TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
                      Operation *o, Attribute a)
-    : kind(k), val(v), op(o) {
+    : kind(k), val(v), op(o), attr(a) {
   switch (kind) {
   // Leaf.
   case TensorExp::Kind::kTensor:
@@ -133,6 +134,7 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -201,7 +203,6 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
   case TensorExp::Kind::kCmpF:
   case TensorExp::Kind::kCmpI:
     assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
-    attr = a;
     children.e0 = x;
     children.e1 = y;
     return;
@@ -337,7 +338,6 @@ LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
 LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
   const LatSetId sNew = conjSet(e, s0, s1, op);
   TensorExp::Kind kind = exp(e).kind;
-
   // Followed by all in s0.
   latSets[sNew].append(latSets[s0]);
   // Map binary 0-y to unary -y.
@@ -381,31 +381,32 @@ LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
                           bool includeLeft, TensorExp::Kind ltrans,
                           Operation *opleft, bool includeRight,
                           TensorExp::Kind rtrans, Operation *opright) {
+  Attribute a = exp(e).attr;
   const LatSetId sNew = conjSet(e, s0, s1, orig);
   // Left Region.
   if (includeLeft) {
     if (opleft)
-      s0 = mapSet(ltrans, s0, Value(), opleft);
+      s0 = mapSet(ltrans, s0, Value(), opleft, a);
     latSets[sNew].append(latSets[s0]);
   }
   // Right Region.
   if (includeRight) {
     if (opright)
-      s1 = mapSet(rtrans, s1, Value(), opright);
+      s1 = mapSet(rtrans, s1, Value(), opright, a);
     latSets[sNew].append(latSets[s1]);
   }
   return sNew;
 }
 
 LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
-                        Operation *op) {
+                        Operation *op, Attribute a) {
   assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
          TensorExp::Kind::kDenseOp == kind);
   const LatSetId sNew = addSet();
   auto &setNew = latSets[sNew];
   for (const LatPointId p : set(s0)) {
     const auto &point = latPoints[p];
-    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
+    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
   }
   return sNew;
 }
@@ -596,6 +597,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -717,6 +719,8 @@ static const char *kindToOpSymbol(TensorExp::Kind kind) {
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
     return "log1p";
+  case TensorExp::Kind::kRelu:
+    return "relu";
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
     return "sin";
@@ -824,6 +828,7 @@ void Merger::dumpExp(ExprId e) const {
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -972,6 +977,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
   case TensorExp::Kind::kExpm1C:
   case TensorExp::Kind::kLog1pF:
   case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kRelu:
   case TensorExp::Kind::kSinF:
   case TensorExp::Kind::kSinC:
   case TensorExp::Kind::kTanhF:
@@ -1001,7 +1007,8 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
     {
       const ExprId e0 = expr.children.e0;
       const Value v = expr.val;
-      return mapSet(kind, buildLattices(e0, i), v);
+      Attribute a = expr.attr;
+      return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
     }
   case TensorExp::Kind::kBinaryBranch:
   case TensorExp::Kind::kSelect:
@@ -1190,10 +1197,26 @@ std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
   return buildTensorExp(op, yield->getOperand(0)).first;
 }
 
+/// Only returns true if we are certain this is a zero (floats: +0.0 only).
+static bool isCertainZero(Value val) {
+  if (auto c = val.getDefiningOp<complex::ConstantOp>()) {
+    ArrayAttr arrayAttr = c.getValue();
+    return cast<FloatAttr>(arrayAttr[0]).getValue().isPosZero() &&
+           cast<FloatAttr>(arrayAttr[1]).getValue().isPosZero();
+  }
+  if (auto c = val.getDefiningOp<arith::ConstantIntOp>())
+    return c.value() == 0;
+  if (auto c = val.getDefiningOp<arith::ConstantFloatOp>())
+    return c.value().isPosZero(); // buildRelu emits +0.0, not -0.0
+  return false;
+}
+
 /// Only returns false if we are certain this is a nonzero.
 bool Merger::maybeZero(ExprId e) const {
   const auto &expr = exp(e);
   if (expr.kind == TensorExp::Kind::kInvariant) {
+    // Note that this is different from isCertainZero() in a subtle
+    // way by always returning true for non-constants.
     if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
       ArrayAttr arrayAttr = c.getValue();
       return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
@@ -1247,6 +1270,21 @@ static bool isAdmissibleBranch(Operation *op, Region &region) {
   return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
 }
 
+// Recognizes a direct GT comparison (cmpi ugt/sgt or cmpf UGT/OGT).
+static bool isGreater(TensorExp::Kind kind, Attribute attr) {
+  if (kind == TensorExp::Kind::kCmpI) { // attr: the cmpi predicate
+    auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue();
+    return pred == arith::CmpIPredicate::ugt ||
+           pred == arith::CmpIPredicate::sgt;
+  }
+  if (kind == TensorExp::Kind::kCmpF) { // attr: the cmpf predicate
+    auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue();
+    return pred == arith::CmpFPredicate::UGT ||
+           pred == arith::CmpFPredicate::OGT;
+  }
+  return false; // Not a comparison expression at all.
+}
+
 std::pair<std::optional<ExprId>, bool>
 Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   // Recursion leaves.
@@ -1266,6 +1304,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
     // or belonging to an enveloping op) is considered invariant.
     return {addInvariantExp(v), /*hasSpDep=*/false};
   }
+
   // Something defined outside is invariant.
   Operation *def = v.getDefiningOp();
   if (def->getBlock() != &op.getRegion().front())
@@ -1352,6 +1391,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       }
     }
   }
+
   // Construct binary operations if subexpressions can be built.
   // See buildLattices() for an explanation of rejecting certain
   // division and shift operations.
@@ -1447,6 +1487,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       }
     }
   }
+
   // Construct ternary operations if subexpressions can be built.
   if (def->getNumOperands() == 3) {
     const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
@@ -1460,6 +1501,26 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         if (isAdmissibleBranch(redop, redop.getRegion()))
           return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
       }
+      if (auto selop = dyn_cast<arith::SelectOp>(def)) {
+        // Recognize an integral or floating-point ReLu(x) = Max(x, 0)
+        // operation inside a very specific ternary select operation.
+        // TODO: capture MIN/MAX/ABS/RELU structure in a more generic way
+        const auto &cnd = exp(*x);
+        if (isGreater(cnd.kind, cnd.attr) &&
+            exp(*y).kind == TensorExp::Kind::kTensor &&
+            exp(*z).kind == TensorExp::Kind::kInvariant &&
+            isCertainZero(exp(*z).val)) {
+          const auto &a = exp(cnd.children.e0);
+          const auto &b = exp(cnd.children.e1);
+          if (a.kind == TensorExp::Kind::kTensor &&
+              a.tensor == exp(*y).tensor &&
+              b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) {
+            return {addExp(TensorExp::Kind::kRelu, *y, detail::kInvalidId,
+                           nullptr, cnd.attr),
+                    yDepSp};
+          }
+        }
+      }
     }
   }
 
@@ -1469,7 +1530,6 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   // tensors).
   if (def->getNumResults() != 1) // only handle single result operation.
     return {std::nullopt, false};
-
   SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
   // Builds all the sub-expressions
   for (Value operand : def->getOperands())
@@ -1489,6 +1549,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       return {e, false};
     }
   }
+
   // Cannot build.
   return {std::nullopt, false};
 }
@@ -1538,6 +1599,22 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
 }
 
+static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
+                       Attribute attr) { // Emits select(cmp(v0, 0), v0, 0)
+  Type tp = v0.getType();
+  auto zero = // Zero constant of the same type as v0.
+      rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
+  Value cmp;
+  if (isa<FloatType>(tp)) { // Reuse the matched cmpf predicate.
+    auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr);
+    cmp = rewriter.create<arith::CmpFOp>(loc, pred, v0, zero);
+  } else { // Non-float: reuse the matched cmpi predicate.
+    auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr);
+    cmp = rewriter.create<arith::CmpIOp>(loc, pred, v0, zero);
+  }
+  return rewriter.create<arith::SelectOp>(loc, cmp, v0, zero);
+}
+
 Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                        Value v1) const {
   const auto &expr = exp(e);
@@ -1574,6 +1651,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
     return rewriter.create<math::Log1pOp>(loc, v0);
   case TensorExp::Kind::kLog1pC:
     return rewriter.create<complex::Log1pOp>(loc, v0);
+  case TensorExp::Kind::kRelu:
+    return buildRelu(rewriter, loc, v0, expr.attr);
   case TensorExp::Kind::kSinF:
     return rewriter.create<math::SinOp>(loc, v0);
   case TensorExp::Kind::kSinC:
@@ -1677,7 +1756,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
   case TensorExp::Kind::kUnary:
     return buildUnaryPresent(rewriter, loc, expr.op, v0);
   case TensorExp::Kind::kSelect:
-    return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
+    return insertYieldOp(rewriter, loc,
+                         cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
                          {v0});
   case TensorExp::Kind::kBinary:
     return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);

diff --git a/mlir/test/Dialect/SparseTensor/sparse_relu.mlir b/mlir/test/Dialect/SparseTensor/sparse_relu.mlir
new file mode 100644
index 0000000000000..25f0c790b43d7
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_relu.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s --sparsification-and-bufferization | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#sparse = #sparse_tensor.encoding<{
+    map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed)
+}>
+
+//
+// Make sure a simple ReLU (select(x > 0, x, 0)) passes the sparsifier.
+//
+// CHECK-LABEL: func.func @relu
+// CHECK:       scf.for
+// CHECK:         scf.for
+// CHECK:           scf.for
+// CHECK:             arith.cmpf ugt
+// CHECK:             arith.select
+//
+func.func @relu(%arg0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
+  %cst = arith.constant 0.000000e+00 : f64
+  %0 = tensor.empty() : tensor<10x20x30xf64>
+  %1 = linalg.generic {
+      indexing_maps = [#map, #map],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<10x20x30xf64, #sparse>)
+      outs(%0 : tensor<10x20x30xf64>) {
+  ^bb0(%in: f64, %out: f64):
+      %2 = arith.cmpf ugt, %in, %cst : f64
+      %3 = arith.select %2, %in, %cst : f64
+      linalg.yield %3 : f64
+  } -> tensor<10x20x30xf64>
+  %cast = tensor.cast %1 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
+  return %cast : tensor<10x20x30xf64, #sparse>
+}

diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 943e7d5c120b8..abc6c70766943 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -236,6 +236,7 @@ class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
     case TensorExp::Kind::kExpm1C:
     case TensorExp::Kind::kLog1pF:
     case TensorExp::Kind::kLog1pC:
+    case TensorExp::Kind::kRelu:
     case TensorExp::Kind::kSinF:
     case TensorExp::Kind::kSinC:
     case TensorExp::Kind::kTanhF:


        


More information about the Mlir-commits mailing list