[Mlir-commits] [mlir] [mlir][sparse] recognize ReLu operation during sparsification (PR #92016)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 13 12:35:14 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Aart Bik (aartbik)
<details>
<summary>Changes</summary>
This is a proof-of-concept recognition of the most basic forms of ReLU operations, used to showcase sparsification of end-to-end PyTorch models. In the long run, we must avoid lowering such constructs too early (which then creates the need to raise them back).
See discussion at
https://discourse.llvm.org/t/min-max-abs-relu-recognition-starter-project/78918
---
Full diff: https://github.com/llvm/llvm-project/pull/92016.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (+2-1)
- (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+90-10)
- (added) mlir/test/Dialect/SparseTensor/sparse_relu.mlir (+34)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 7f9820df984b2..b8d278152dc05 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -144,6 +144,7 @@ enum class TensorExp::Kind {
kExpm1C,
kLog1pF,
kLog1pC,
+ kRelu,
kSinF,
kSinC,
kTanhF,
@@ -316,7 +317,7 @@ class Merger {
/// lattice point on an expression E is simply copied over, but with OP E
/// as new expression. Returns the identifier of the new set.
LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v = Value(),
- Operation *op = nullptr);
+ Operation *op = nullptr, Attribute attr = nullptr);
/// Maps the binary operator to the same operation but with one of its operand
/// set to zero, i.e. each lattice point on an expression E is simply copied
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 308fbd965259d..0258f797143cb 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -44,6 +44,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
case TensorExp::Kind::kExpm1C:
case TensorExp::Kind::kLog1pF:
case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kRelu:
case TensorExp::Kind::kSinF:
case TensorExp::Kind::kSinC:
case TensorExp::Kind::kTanhF:
@@ -104,7 +105,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
Operation *o, Attribute a)
- : kind(k), val(v), op(o) {
+ : kind(k), val(v), op(o), attr(a) {
switch (kind) {
// Leaf.
case TensorExp::Kind::kTensor:
@@ -133,6 +134,7 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
case TensorExp::Kind::kExpm1C:
case TensorExp::Kind::kLog1pF:
case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kRelu:
case TensorExp::Kind::kSinF:
case TensorExp::Kind::kSinC:
case TensorExp::Kind::kTanhF:
@@ -201,7 +203,6 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
case TensorExp::Kind::kCmpF:
case TensorExp::Kind::kCmpI:
assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
- attr = a;
children.e0 = x;
children.e1 = y;
return;
@@ -337,7 +338,6 @@ LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
const LatSetId sNew = conjSet(e, s0, s1, op);
TensorExp::Kind kind = exp(e).kind;
-
// Followed by all in s0.
latSets[sNew].append(latSets[s0]);
// Map binary 0-y to unary -y.
@@ -381,31 +381,32 @@ LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
bool includeLeft, TensorExp::Kind ltrans,
Operation *opleft, bool includeRight,
TensorExp::Kind rtrans, Operation *opright) {
+ Attribute a = exp(e).attr;
const LatSetId sNew = conjSet(e, s0, s1, orig);
// Left Region.
if (includeLeft) {
if (opleft)
- s0 = mapSet(ltrans, s0, Value(), opleft);
+ s0 = mapSet(ltrans, s0, Value(), opleft, a);
latSets[sNew].append(latSets[s0]);
}
// Right Region.
if (includeRight) {
if (opright)
- s1 = mapSet(rtrans, s1, Value(), opright);
+ s1 = mapSet(rtrans, s1, Value(), opright, a);
latSets[sNew].append(latSets[s1]);
}
return sNew;
}
LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
- Operation *op) {
+ Operation *op, Attribute a) {
assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
TensorExp::Kind::kDenseOp == kind);
const LatSetId sNew = addSet();
auto &setNew = latSets[sNew];
for (const LatPointId p : set(s0)) {
const auto &point = latPoints[p];
- setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
+ setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
}
return sNew;
}
@@ -596,6 +597,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
case TensorExp::Kind::kExpm1C:
case TensorExp::Kind::kLog1pF:
case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kRelu:
case TensorExp::Kind::kSinF:
case TensorExp::Kind::kSinC:
case TensorExp::Kind::kTanhF:
@@ -717,6 +719,8 @@ static const char *kindToOpSymbol(TensorExp::Kind kind) {
case TensorExp::Kind::kLog1pF:
case TensorExp::Kind::kLog1pC:
return "log1p";
+ case TensorExp::Kind::kRelu:
+ return "relu";
case TensorExp::Kind::kSinF:
case TensorExp::Kind::kSinC:
return "sin";
@@ -824,6 +828,7 @@ void Merger::dumpExp(ExprId e) const {
case TensorExp::Kind::kExpm1C:
case TensorExp::Kind::kLog1pF:
case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kRelu:
case TensorExp::Kind::kSinF:
case TensorExp::Kind::kSinC:
case TensorExp::Kind::kTanhF:
@@ -972,6 +977,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
case TensorExp::Kind::kExpm1C:
case TensorExp::Kind::kLog1pF:
case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kRelu:
case TensorExp::Kind::kSinF:
case TensorExp::Kind::kSinC:
case TensorExp::Kind::kTanhF:
@@ -1001,7 +1007,8 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
{
const ExprId e0 = expr.children.e0;
const Value v = expr.val;
- return mapSet(kind, buildLattices(e0, i), v);
+ Attribute a = expr.attr;
+ return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
}
case TensorExp::Kind::kBinaryBranch:
case TensorExp::Kind::kSelect:
@@ -1190,10 +1197,26 @@ std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
return buildTensorExp(op, yield->getOperand(0)).first;
}
+/// Only returns true if we are certain this is a zero.
+static bool isCertainZero(Value val) {
+ if (auto c = val.getDefiningOp<complex::ConstantOp>()) {
+ ArrayAttr arrayAttr = c.getValue();
+ return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
+ cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
+ }
+ if (auto c = val.getDefiningOp<arith::ConstantIntOp>())
+ return c.value() == 0;
+ if (auto c = val.getDefiningOp<arith::ConstantFloatOp>())
+ return c.value().isZero();
+ return false;
+}
+
/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(ExprId e) const {
const auto &expr = exp(e);
if (expr.kind == TensorExp::Kind::kInvariant) {
+ // Note that this is different from isCertainZero() in a subtle
+ // way by always returning true for non-constants.
if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
ArrayAttr arrayAttr = c.getValue();
return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
@@ -1247,6 +1270,21 @@ static bool isAdmissibleBranch(Operation *op, Region ®ion) {
return isAdmissibleBranchExp(op, ®ion.front(), yield->getOperand(0));
}
+// Recognizes a direct GT comparison.
+static bool isGreater(TensorExp::Kind kind, Attribute attr) {
+ if (kind == TensorExp::Kind::kCmpI) {
+ auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue();
+ return pred == arith::CmpIPredicate::ugt ||
+ pred == arith::CmpIPredicate::sgt;
+ }
+ if (kind == TensorExp::Kind::kCmpF) {
+ auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue();
+ return pred == arith::CmpFPredicate::UGT ||
+ pred == arith::CmpFPredicate::OGT;
+ }
+ return false;
+}
+
std::pair<std::optional<ExprId>, bool>
Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// Recursion leaves.
@@ -1266,6 +1304,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// or belonging to an enveloping op) is considered invariant.
return {addInvariantExp(v), /*hasSpDep=*/false};
}
+
// Something defined outside is invariant.
Operation *def = v.getDefiningOp();
if (def->getBlock() != &op.getRegion().front())
@@ -1352,6 +1391,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
}
}
}
+
// Construct binary operations if subexpressions can be built.
// See buildLattices() for an explanation of rejecting certain
// division and shift operations.
@@ -1447,6 +1487,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
}
}
}
+
// Construct ternary operations if subexpressions can be built.
if (def->getNumOperands() == 3) {
const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
@@ -1460,6 +1501,26 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
if (isAdmissibleBranch(redop, redop.getRegion()))
return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
}
+ if (auto selop = dyn_cast<arith::SelectOp>(def)) {
+ // Recognize an integral or floating-point ReLu(x) = Max(x, 0)
+ // operation inside a very specific ternary select operation.
+ // TODO: capture MIN/MAX/ABS/RELU structure in a more generic way
+ const auto &cnd = exp(*x);
+ if (isGreater(cnd.kind, cnd.attr) &&
+ exp(*y).kind == TensorExp::Kind::kTensor &&
+ exp(*z).kind == TensorExp::Kind::kInvariant &&
+ isCertainZero(exp(*z).val)) {
+ const auto &a = exp(cnd.children.e0);
+ const auto &b = exp(cnd.children.e1);
+ if (a.kind == TensorExp::Kind::kTensor &&
+ a.tensor == exp(*y).tensor &&
+ b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) {
+ return {addExp(TensorExp::Kind::kRelu, *y, detail::kInvalidId,
+ nullptr, cnd.attr),
+ yDepSp};
+ }
+ }
+ }
}
}
@@ -1469,7 +1530,6 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// tensors).
if (def->getNumResults() != 1) // only handle single result operation.
return {std::nullopt, false};
-
SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
// Builds all the sub-expressions
for (Value operand : def->getOperands())
@@ -1489,6 +1549,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return {e, false};
}
}
+
// Cannot build.
return {std::nullopt, false};
}
@@ -1538,6 +1599,22 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}
+static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
+ Attribute attr) {
+ Type tp = v0.getType();
+ auto zero =
+ rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
+ Value cmp;
+ if (isa<FloatType>(tp)) {
+ auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr);
+ cmp = rewriter.create<arith::CmpFOp>(loc, pred, v0, zero);
+ } else {
+ auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr);
+ cmp = rewriter.create<arith::CmpIOp>(loc, pred, v0, zero);
+ }
+ return rewriter.create<arith::SelectOp>(loc, cmp, v0, zero);
+}
+
Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
Value v1) const {
const auto &expr = exp(e);
@@ -1574,6 +1651,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
return rewriter.create<math::Log1pOp>(loc, v0);
case TensorExp::Kind::kLog1pC:
return rewriter.create<complex::Log1pOp>(loc, v0);
+ case TensorExp::Kind::kRelu:
+ return buildRelu(rewriter, loc, v0, expr.attr);
case TensorExp::Kind::kSinF:
return rewriter.create<math::SinOp>(loc, v0);
case TensorExp::Kind::kSinC:
@@ -1677,7 +1756,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
case TensorExp::Kind::kUnary:
return buildUnaryPresent(rewriter, loc, expr.op, v0);
case TensorExp::Kind::kSelect:
- return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
+ return insertYieldOp(rewriter, loc,
+ cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
{v0});
case TensorExp::Kind::kBinary:
return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_relu.mlir b/mlir/test/Dialect/SparseTensor/sparse_relu.mlir
new file mode 100644
index 0000000000000..25f0c790b43d7
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_relu.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s --sparsification-and-bufferization | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#sparse = #sparse_tensor.encoding<{
+ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed)
+}>
+
+//
+// Make sure a simple ReLU passes the sparsifier
+//
+// CHECK-LABEL: func.func @relu
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: arith.cmpf ugt
+// CHECK: arith.select
+//
+func.func @relu(%arg0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
+ %cst = arith.constant 0.000000e+00 : f64
+ %0 = tensor.empty() : tensor<10x20x30xf64>
+ %1 = linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<10x20x30xf64, #sparse>)
+ outs(%0 : tensor<10x20x30xf64>) {
+ ^bb0(%in: f64, %out: f64):
+ %2 = arith.cmpf ugt, %in, %cst : f64
+ %3 = arith.select %2, %in, %cst : f64
+ linalg.yield %3 : f64
+ } -> tensor<10x20x30xf64>
+ %cast = tensor.cast %1 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
+ return %cast : tensor<10x20x30xf64, #sparse>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/92016
More information about the Mlir-commits
mailing list