[Mlir-commits] [mlir] 1f58ae8 - [mlir][sparse] Making `TensorExp::Kind` a nested enum-class
wren romano
llvmlistbot at llvm.org
Mon Mar 20 16:12:39 PDT 2023
Author: wren romano
Date: 2023-03-20T16:12:31-07:00
New Revision: 1f58ae80661b7c9738ca5cff08ff8246ddecf987
URL: https://github.com/llvm/llvm-project/commit/1f58ae80661b7c9738ca5cff08ff8246ddecf987
DIFF: https://github.com/llvm/llvm-project/commit/1f58ae80661b7c9738ca5cff08ff8246ddecf987.diff
LOG: [mlir][sparse] Making `TensorExp::Kind` a nested enum-class
This improves namespacing, and follows the pattern used for "Kind" enums elsewhere in MLIR.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D146086
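For context, the visible effect of the change is that the expression-kind enumerators move from namespace scope into the class scope and become strongly typed. A minimal sketch of the before/after shape (illustrative enumerator subset, not the full list from Merger.h):

// Before: unscoped enum; enumerators such as kMulF leak into the
// enclosing namespace and convert implicitly to integers.
enum Kind { kTensor = 0, kMulF, kAddF };

// After: scoped enum nested in the struct that owns it; enumerators
// no longer convert implicitly, and call sites must spell out the
// full path, as seen throughout this diff.
struct TensorExp {
  enum class Kind { kTensor = 0, kMulF, kAddF };
};

bool isMulF(TensorExp::Kind k) { return k == TensorExp::Kind::kMulF; }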
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
Removed:
################################################################################
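The header keeps the enumerator list out of the struct body by forward-declaring the nested enum and defining it afterward at namespace scope, as the first hunk below shows. A self-contained sketch of that C++11 pattern (names shortened; not the real declarations):

struct Exp {
  // Opaque declaration: a scoped enum has a default underlying type
  // of int, so it can be declared here and defined later.
  enum class Kind;
  Kind kind;
};

// Out-of-class definition of the nested enum, mirroring
// `enum class TensorExp::Kind { ... };` in Merger.h.
enum class Exp::Kind { kTensor = 0, kInvariant, kLoopVar };

int main() {
  Exp e{Exp::Kind::kTensor};
  return e.kind == Exp::Kind::kTensor ? 0 : 1;
}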
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 4a83237fb1634..6e39404bb28aa 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -23,87 +23,6 @@
namespace mlir {
namespace sparse_tensor {
-/// Tensor expression kind.
-///
-/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
-/// That is, its argument is a `LoopId` identifying the loop-variable
-/// in question, and its value will be the current iteration's value
-/// of that loop-variable. See the `LoopId` documentation for more details.
-//
-// TODO: make this an `enum class` nested in the `TensorExp` class;
-// to improve namespacing, and match the pattern used by other "Kind"
-// enums in MLIR.
-//
-// TODO: Modify this definition so that the numeric values already encode
-// the `ExpArity` (while extending the notion of "arity" to include not
-// just the number of `ExprId` children the node has, but also whether the
-// node has a `Value` and/or `Operation*`). Doing this will avoid needing
-// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor,
-// and should help clean up a few other places as well.
-enum Kind {
- // Leaf.
- kTensor = 0,
- kInvariant,
- kLoopVar,
- // Unary operations.
- kAbsF,
- kAbsC,
- kAbsI,
- kCeilF,
- kFloorF,
- kSqrtF,
- kSqrtC,
- kExpm1F,
- kExpm1C,
- kLog1pF,
- kLog1pC,
- kSinF,
- kSinC,
- kTanhF,
- kTanhC,
- kNegF,
- kNegC,
- kNegI,
- kTruncF,
- kExtF,
- kCastFS, // signed
- kCastFU, // unsigned
- kCastSF, // signed
- kCastUF, // unsigned
- kCastS, // signed
- kCastU, // unsigned
- kCastIdx,
- kTruncI,
- kCIm, // complex.im
- kCRe, // complex.re
- kBitCast,
- kBinaryBranch, // semiring unary branch created from a binary op
- kUnary, // semiring unary op
- kSelect, // custom selection criteria
- // Binary operations.
- kMulF,
- kMulC,
- kMulI,
- kDivF,
- kDivC, // complex
- kDivS, // signed
- kDivU, // unsigned
- kAddF,
- kAddC,
- kAddI,
- kSubF,
- kSubC,
- kSubI,
- kAndI,
- kOrI,
- kXorI,
- kShrS, // signed
- kShrU, // unsigned
- kShlI,
- kBinary, // semiring binary op
- kReduce, // semiring reduction op
-};
-
// TODO: These type aliases currently only serve to make the code more
// self-documenting; however, because they are not type-checked they can
// do nothing to prevent mixups. We should really change them from mere
@@ -169,6 +88,8 @@ struct Children {
/// Tensor expression. Represents an MLIR expression in tensor index notation.
struct TensorExp {
+ enum class Kind;
+
// The `x` parameter has different types depending on the value of the
// `k` parameter. The correspondences are:
// * `kTensor` -> `TensorId`
@@ -207,6 +128,83 @@ struct TensorExp {
Operation *op;
};
+/// Tensor expression kind.
+///
+/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
+/// That is, its argument is a `LoopId` identifying the loop-variable
+/// in question, and its value will be the current iteration's value
+/// of that loop-variable. See the `LoopId` documentation for more details.
+//
+// TODO: Modify this definition so that the numeric values already encode
+// the `ExpArity` (while extending the notion of "arity" to include not
+// just the number of `ExprId` children the node has, but also whether the
+// node has a `Value` and/or `Operation*`). Doing this will avoid needing
+// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor,
+// and should help clean up a few other places as well.
+enum class TensorExp::Kind {
+ // Leaf.
+ kTensor = 0,
+ kInvariant,
+ kLoopVar,
+ // Unary operations.
+ kAbsF,
+ kAbsC,
+ kAbsI,
+ kCeilF,
+ kFloorF,
+ kSqrtF,
+ kSqrtC,
+ kExpm1F,
+ kExpm1C,
+ kLog1pF,
+ kLog1pC,
+ kSinF,
+ kSinC,
+ kTanhF,
+ kTanhC,
+ kNegF,
+ kNegC,
+ kNegI,
+ kTruncF,
+ kExtF,
+ kCastFS, // signed
+ kCastFU, // unsigned
+ kCastSF, // signed
+ kCastUF, // unsigned
+ kCastS, // signed
+ kCastU, // unsigned
+ kCastIdx,
+ kTruncI,
+ kCIm, // complex.im
+ kCRe, // complex.re
+ kBitCast,
+ kBinaryBranch, // semiring unary branch created from a binary op
+ kUnary, // semiring unary op
+ kSelect, // custom selection criteria
+ // Binary operations.
+ kMulF,
+ kMulC,
+ kMulI,
+ kDivF,
+ kDivC, // complex
+ kDivS, // signed
+ kDivU, // unsigned
+ kAddF,
+ kAddC,
+ kAddI,
+ kSubF,
+ kSubC,
+ kSubI,
+ kAndI,
+ kOrI,
+ kXorI,
+ kShrS, // signed
+ kShrU, // unsigned
+ kShlI,
+ kBinary, // semiring binary op
+ kReduce, // semiring reduction op
+};
+
/// Lattice point. Each lattice point consists of a formal conjunction
/// of `TensorLoopId`s, together with the identifier of the corresponding
/// tensor expression. The formal conjunction is represented as a set of
@@ -271,12 +269,12 @@ class Merger {
/// Constructs a new tensor expression, and returns its identifier.
/// The type of the `e0` argument varies according to the value of the
/// `k` argument, as described by the `TensorExp` ctor.
- ExprId addExp(Kind k, unsigned e0, ExprId e1 = kInvalidId, Value v = Value(),
- Operation *op = nullptr);
- ExprId addExp(Kind k, ExprId e, Value v, Operation *op = nullptr) {
+ ExprId addExp(TensorExp::Kind k, unsigned e0, ExprId e1 = kInvalidId,
+ Value v = Value(), Operation *op = nullptr);
+ ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr) {
return addExp(k, e, kInvalidId, v, op);
}
- ExprId addExp(Kind k, Value v, Operation *op = nullptr) {
+ ExprId addExp(TensorExp::Kind k, Value v, Operation *op = nullptr) {
return addExp(k, kInvalidId, kInvalidId, v, op);
}
@@ -290,30 +288,31 @@ class Merger {
/// of `LoopId` (effectively constructing a larger "intersection" of those
/// loops) with a newly constructed tensor (sub)expression of given kind.
/// Returns the identifier of the new lattice point.
- LatPointId conjLat(Kind kind, LatPointId p0, LatPointId p1,
+ LatPointId conjLat(TensorExp::Kind kind, LatPointId p0, LatPointId p1,
Operation *op = nullptr);
/// Conjunctive merge of two lattice sets: `(s0 /\_op s1)`.
/// Returns the identifier of the new set.
- LatSetId conjSet(Kind kind, LatSetId s0, LatSetId s1,
+ LatSetId conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
Operation *op = nullptr);
/// Disjunctive merge of two lattice sets: `(s0 /\_op s1, s0, s1)`.
/// Returns the identifier of the new set.
- LatSetId disjSet(Kind kind, LatSetId s0, LatSetId s1,
+ LatSetId disjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
Operation *op = nullptr);
/// Disjunctive merge of two lattice sets with custom handling of the
/// overlap, left, and right regions. Any region may be left missing
/// in the output. Returns the identifier of the new set.
- LatSetId combiSet(Kind kind, LatSetId s0, LatSetId s1, Operation *orig,
- bool includeLeft, Kind ltrans, Operation *opleft,
- bool includeRight, Kind rtrans, Operation *opright);
+ LatSetId combiSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
+ Operation *orig, bool includeLeft, TensorExp::Kind ltrans,
+ Operation *opleft, bool includeRight,
+ TensorExp::Kind rtrans, Operation *opright);
/// Maps the unary operator over the lattice set of the operand, i.e. each
/// lattice point on an expression E is simply copied over, but with OP E
/// as new expression. Returns the identifier of the new set.
- LatSetId mapSet(Kind kind, LatSetId s, Value v = Value(),
+ LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v = Value(),
Operation *op = nullptr);
/// Optimizes the iteration lattice points in the given set. This
@@ -377,7 +376,8 @@ class Merger {
/// Returns true if the expression is `(kTensor t)`.
bool expIsTensor(ExprId e, TensorId t) const {
- return tensorExps[e].kind == kTensor && tensorExps[e].tensor == t;
+ return tensorExps[e].kind == TensorExp::Kind::kTensor &&
+ tensorExps[e].tensor == t;
}
/// Returns true if the expression contains the tensor as an operand.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index f189b14c60c7e..d8aeb44811534 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1045,8 +1045,9 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
if (!rhs) {
// Only unary and binary are allowed to return uninitialized rhs
// to indicate missing output.
- assert(env.exp(exp).kind == kUnary || env.exp(exp).kind == kBinary);
- } else if (env.exp(exp).kind == kSelect) {
+ assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
+ env.exp(exp).kind == TensorExp::Kind::kBinary);
+ } else if (env.exp(exp).kind == TensorExp::Kind::kSelect) {
// Select operation insertion.
Value chain = env.getInsertionChain();
scf::IfOp ifOp =
@@ -1114,28 +1115,29 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
return Value();
const TensorExp &exp = env.exp(e);
const auto kind = exp.kind;
- if (kind == Kind::kTensor)
+ if (kind == TensorExp::Kind::kTensor)
return genTensorLoad(env, rewriter, e);
- if (kind == Kind::kInvariant)
+ if (kind == TensorExp::Kind::kInvariant)
return genInvariantValue(env, e);
- if (kind == Kind::kLoopVar)
+ if (kind == TensorExp::Kind::kLoopVar)
return env.getLoopVar(exp.loop);
- if (kind == Kind::kReduce)
+ if (kind == TensorExp::Kind::kReduce)
env.startCustomReduc(e); // enter custom
Value v0 = genExp(env, rewriter, exp.children.e0, ldx);
Value v1 = genExp(env, rewriter, exp.children.e1, ldx);
Value ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
- if (ee && (kind == Kind::kUnary || kind == Kind::kBinary ||
- kind == Kind::kBinaryBranch || kind == Kind::kReduce ||
- kind == Kind::kSelect))
+ if (ee &&
+ (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
+ kind == TensorExp::Kind::kBinaryBranch ||
+ kind == TensorExp::Kind::kReduce || kind == TensorExp::Kind::kSelect))
ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
- if (kind == Kind::kReduce)
+ if (kind == TensorExp::Kind::kReduce)
env.endCustomReduc(); // exit custom
- if (kind == kSelect) {
+ if (kind == TensorExp::Kind::kSelect) {
assert(!exp.val);
env.exp(e).val = v0; // Preserve value for later use.
}
@@ -1148,7 +1150,7 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
LoopId ldx, bool atStart) {
if (exp == kInvalidId)
return;
- if (env.exp(exp).kind == Kind::kTensor) {
+ if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
// Inspect tensor indices.
bool isAtLoop = ldx == kInvalidId;
linalg::GenericOp op = env.op();
@@ -1192,18 +1194,18 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
// Start or end loop invariant hoisting of a tensor load.
env.exp(exp).val = atStart ? genTensorLoad(env, builder, exp) : Value();
}
- } else if (env.exp(exp).kind != Kind::kInvariant &&
- env.exp(exp).kind != Kind::kLoopVar) {
+ } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
+ env.exp(exp).kind != TensorExp::Kind::kLoopVar) {
// Traverse into the binary operations. Note that we only hoist
// tensor loads, since subsequent MLIR/LLVM passes know how to
// deal with all other kinds of derived loop invariants.
- if (env.exp(exp).kind == Kind::kReduce)
+ if (env.exp(exp).kind == TensorExp::Kind::kReduce)
env.startCustomReduc(exp); // enter custom
const ExprId e0 = env.exp(exp).children.e0;
const ExprId e1 = env.exp(exp).children.e1;
genInvariants(env, builder, e0, ldx, atStart);
genInvariants(env, builder, e1, ldx, atStart);
- if (env.exp(exp).kind == Kind::kReduce)
+ if (env.exp(exp).kind == TensorExp::Kind::kReduce)
env.endCustomReduc(); // exit custom
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 40db5411132b4..4a8c3cbfbe584 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -25,70 +25,70 @@ enum class ExpArity {
kBinary,
};
-static ExpArity getExpArity(Kind k) {
+static ExpArity getExpArity(TensorExp::Kind k) {
switch (k) {
// Leaf.
- case kTensor:
- case kInvariant:
- case kLoopVar:
+ case TensorExp::Kind::kTensor:
+ case TensorExp::Kind::kInvariant:
+ case TensorExp::Kind::kLoopVar:
return ExpArity::kNullary;
- case kAbsF:
- case kAbsC:
- case kAbsI:
- case kCeilF:
- case kFloorF:
- case kSqrtF:
- case kSqrtC:
- case kExpm1F:
- case kExpm1C:
- case kLog1pF:
- case kLog1pC:
- case kSinF:
- case kSinC:
- case kTanhF:
- case kTanhC:
- case kTruncF:
- case kExtF:
- case kCastFS:
- case kCastFU:
- case kCastSF:
- case kCastUF:
- case kCastS:
- case kCastU:
- case kCastIdx:
- case kTruncI:
- case kCIm:
- case kCRe:
- case kBitCast:
- case kBinaryBranch:
- case kUnary:
- case kSelect:
- case kNegF:
- case kNegC:
- case kNegI:
+ case TensorExp::Kind::kAbsF:
+ case TensorExp::Kind::kAbsC:
+ case TensorExp::Kind::kAbsI:
+ case TensorExp::Kind::kCeilF:
+ case TensorExp::Kind::kFloorF:
+ case TensorExp::Kind::kSqrtF:
+ case TensorExp::Kind::kSqrtC:
+ case TensorExp::Kind::kExpm1F:
+ case TensorExp::Kind::kExpm1C:
+ case TensorExp::Kind::kLog1pF:
+ case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kSinF:
+ case TensorExp::Kind::kSinC:
+ case TensorExp::Kind::kTanhF:
+ case TensorExp::Kind::kTanhC:
+ case TensorExp::Kind::kTruncF:
+ case TensorExp::Kind::kExtF:
+ case TensorExp::Kind::kCastFS:
+ case TensorExp::Kind::kCastFU:
+ case TensorExp::Kind::kCastSF:
+ case TensorExp::Kind::kCastUF:
+ case TensorExp::Kind::kCastS:
+ case TensorExp::Kind::kCastU:
+ case TensorExp::Kind::kCastIdx:
+ case TensorExp::Kind::kTruncI:
+ case TensorExp::Kind::kCIm:
+ case TensorExp::Kind::kCRe:
+ case TensorExp::Kind::kBitCast:
+ case TensorExp::Kind::kBinaryBranch:
+ case TensorExp::Kind::kUnary:
+ case TensorExp::Kind::kSelect:
+ case TensorExp::Kind::kNegF:
+ case TensorExp::Kind::kNegC:
+ case TensorExp::Kind::kNegI:
return ExpArity::kUnary;
// Binary operations.
- case kDivF:
- case kDivC:
- case kDivS:
- case kDivU:
- case kShrS:
- case kShrU:
- case kShlI:
- case kMulF:
- case kMulC:
- case kMulI:
- case kAndI:
- case kAddF:
- case kAddC:
- case kAddI:
- case kOrI:
- case kXorI:
- case kBinary:
- case kReduce:
- case kSubF:
- case kSubC:
- case kSubI:
+ case TensorExp::Kind::kDivF:
+ case TensorExp::Kind::kDivC:
+ case TensorExp::Kind::kDivS:
+ case TensorExp::Kind::kDivU:
+ case TensorExp::Kind::kShrS:
+ case TensorExp::Kind::kShrU:
+ case TensorExp::Kind::kShlI:
+ case TensorExp::Kind::kMulF:
+ case TensorExp::Kind::kMulC:
+ case TensorExp::Kind::kMulI:
+ case TensorExp::Kind::kAndI:
+ case TensorExp::Kind::kAddF:
+ case TensorExp::Kind::kAddC:
+ case TensorExp::Kind::kAddI:
+ case TensorExp::Kind::kOrI:
+ case TensorExp::Kind::kXorI:
+ case TensorExp::Kind::kBinary:
+ case TensorExp::Kind::kReduce:
+ case TensorExp::Kind::kSubF:
+ case TensorExp::Kind::kSubC:
+ case TensorExp::Kind::kSubI:
return ExpArity::kBinary;
}
llvm_unreachable("unexpected kind");
@@ -102,64 +102,64 @@ TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o)
: kind(k), val(v), op(o) {
switch (kind) {
// Leaf.
- case kTensor:
+ case TensorExp::Kind::kTensor:
assert(x != kInvalidId && y == kInvalidId && !v && !o);
tensor = x;
break;
- case kInvariant:
+ case TensorExp::Kind::kInvariant:
assert(x == kInvalidId && y == kInvalidId && v && !o);
break;
- case kLoopVar:
+ case TensorExp::Kind::kLoopVar:
assert(x != kInvalidId && y == kInvalidId && !v && !o);
loop = x;
break;
// Unary operations.
- case kAbsF:
- case kAbsC:
- case kAbsI:
- case kCeilF:
- case kFloorF:
- case kSqrtF:
- case kSqrtC:
- case kExpm1F:
- case kExpm1C:
- case kLog1pF:
- case kLog1pC:
- case kSinF:
- case kSinC:
- case kTanhF:
- case kTanhC:
- case kNegF:
- case kNegC:
- case kNegI:
- case kCIm:
- case kCRe:
+ case TensorExp::Kind::kAbsF:
+ case TensorExp::Kind::kAbsC:
+ case TensorExp::Kind::kAbsI:
+ case TensorExp::Kind::kCeilF:
+ case TensorExp::Kind::kFloorF:
+ case TensorExp::Kind::kSqrtF:
+ case TensorExp::Kind::kSqrtC:
+ case TensorExp::Kind::kExpm1F:
+ case TensorExp::Kind::kExpm1C:
+ case TensorExp::Kind::kLog1pF:
+ case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kSinF:
+ case TensorExp::Kind::kSinC:
+ case TensorExp::Kind::kTanhF:
+ case TensorExp::Kind::kTanhC:
+ case TensorExp::Kind::kNegF:
+ case TensorExp::Kind::kNegC:
+ case TensorExp::Kind::kNegI:
+ case TensorExp::Kind::kCIm:
+ case TensorExp::Kind::kCRe:
assert(x != kInvalidId && y == kInvalidId && !v && !o);
children.e0 = x;
children.e1 = y;
break;
- case kTruncF:
- case kExtF:
- case kCastFS:
- case kCastFU:
- case kCastSF:
- case kCastUF:
- case kCastS:
- case kCastU:
- case kCastIdx:
- case kTruncI:
- case kBitCast:
+ case TensorExp::Kind::kTruncF:
+ case TensorExp::Kind::kExtF:
+ case TensorExp::Kind::kCastFS:
+ case TensorExp::Kind::kCastFU:
+ case TensorExp::Kind::kCastSF:
+ case TensorExp::Kind::kCastUF:
+ case TensorExp::Kind::kCastS:
+ case TensorExp::Kind::kCastU:
+ case TensorExp::Kind::kCastIdx:
+ case TensorExp::Kind::kTruncI:
+ case TensorExp::Kind::kBitCast:
assert(x != kInvalidId && y == kInvalidId && v && !o);
children.e0 = x;
children.e1 = y;
break;
- case kBinaryBranch:
- case kSelect:
+ case TensorExp::Kind::kBinaryBranch:
+ case TensorExp::Kind::kSelect:
assert(x != kInvalidId && y == kInvalidId && !v && o);
children.e0 = x;
children.e1 = y;
break;
- case kUnary:
+ case TensorExp::Kind::kUnary:
// No assertion on y can be made, as the branching paths involve both
// a unary (`mapSet`) and binary (`disjSet`) pathway.
assert(x != kInvalidId && !v && o);
@@ -167,31 +167,31 @@ TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o)
children.e1 = y;
break;
// Binary operations.
- case kMulF:
- case kMulC:
- case kMulI:
- case kDivF:
- case kDivC:
- case kDivS:
- case kDivU:
- case kAddF:
- case kAddC:
- case kAddI:
- case kSubF:
- case kSubC:
- case kSubI:
- case kAndI:
- case kOrI:
- case kXorI:
- case kShrS:
- case kShrU:
- case kShlI:
+ case TensorExp::Kind::kMulF:
+ case TensorExp::Kind::kMulC:
+ case TensorExp::Kind::kMulI:
+ case TensorExp::Kind::kDivF:
+ case TensorExp::Kind::kDivC:
+ case TensorExp::Kind::kDivS:
+ case TensorExp::Kind::kDivU:
+ case TensorExp::Kind::kAddF:
+ case TensorExp::Kind::kAddC:
+ case TensorExp::Kind::kAddI:
+ case TensorExp::Kind::kSubF:
+ case TensorExp::Kind::kSubC:
+ case TensorExp::Kind::kSubI:
+ case TensorExp::Kind::kAndI:
+ case TensorExp::Kind::kOrI:
+ case TensorExp::Kind::kXorI:
+ case TensorExp::Kind::kShrS:
+ case TensorExp::Kind::kShrU:
+ case TensorExp::Kind::kShlI:
assert(x != kInvalidId && y != kInvalidId && !v && !o);
children.e0 = x;
children.e1 = y;
break;
- case kBinary:
- case kReduce:
+ case TensorExp::Kind::kBinary:
+ case TensorExp::Kind::kReduce:
assert(x != kInvalidId && y != kInvalidId && !v && o);
children.e0 = x;
children.e1 = y;
@@ -231,9 +231,11 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
// Lattice methods.
//===----------------------------------------------------------------------===//
-ExprId Merger::addExp(Kind k, unsigned x, ExprId y, Value v, Operation *op) {
+ExprId Merger::addExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
+ Operation *op) {
const ExprId e = tensorExps.size();
- assert((k != kTensor || x < numTensors) && (k != kLoopVar || x < numLoops));
+ assert((k != TensorExp::Kind::kTensor || x < numTensors) &&
+ (k != TensorExp::Kind::kLoopVar || x < numLoops));
tensorExps.emplace_back(k, x, y, v, op);
return e;
}
@@ -251,7 +253,7 @@ LatSetId Merger::addSet() {
return s;
}
-LatPointId Merger::conjLat(Kind kind, LatPointId p0, LatPointId p1,
+LatPointId Merger::conjLat(TensorExp::Kind kind, LatPointId p0, LatPointId p1,
Operation *op) {
const LatPointId p = latPoints.size();
BitVector bits(latPoints[p0].bits);
@@ -262,7 +264,8 @@ LatPointId Merger::conjLat(Kind kind, LatPointId p0, LatPointId p1,
return p;
}
-LatSetId Merger::conjSet(Kind kind, LatSetId s0, LatSetId s1, Operation *op) {
+LatSetId Merger::conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
+ Operation *op) {
const LatSetId s = addSet();
for (const LatPointId p0 : latSets[s0])
for (const LatPointId p1 : latSets[s1])
@@ -270,28 +273,31 @@ LatSetId Merger::conjSet(Kind kind, LatSetId s0, LatSetId s1, Operation *op) {
return s;
}
-LatSetId Merger::disjSet(Kind kind, LatSetId s0, LatSetId s1, Operation *op) {
+LatSetId Merger::disjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
+ Operation *op) {
const LatSetId s = conjSet(kind, s0, s1, op);
// Followed by all in s0.
for (const LatPointId p : latSets[s0])
latSets[s].push_back(p);
// Map binary 0-y to unary -y.
// TODO: move this if-else logic into buildLattices
- if (kind == kSubF)
- s1 = mapSet(kNegF, s1);
- else if (kind == kSubC)
- s1 = mapSet(kNegC, s1);
- else if (kind == kSubI)
- s1 = mapSet(kNegI, s1);
+ if (kind == TensorExp::Kind::kSubF)
+ s1 = mapSet(TensorExp::Kind::kNegF, s1);
+ else if (kind == TensorExp::Kind::kSubC)
+ s1 = mapSet(TensorExp::Kind::kNegC, s1);
+ else if (kind == TensorExp::Kind::kSubI)
+ s1 = mapSet(TensorExp::Kind::kNegI, s1);
// Followed by all in s1.
for (const LatPointId p : latSets[s1])
latSets[s].push_back(p);
return s;
}
-LatSetId Merger::combiSet(Kind kind, LatSetId s0, LatSetId s1, Operation *orig,
- bool includeLeft, Kind ltrans, Operation *opleft,
- bool includeRight, Kind rtrans, Operation *opright) {
+LatSetId Merger::combiSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
+ Operation *orig, bool includeLeft,
+ TensorExp::Kind ltrans, Operation *opleft,
+ bool includeRight, TensorExp::Kind rtrans,
+ Operation *opright) {
const LatSetId s = conjSet(kind, s0, s1, orig);
// Left Region.
if (includeLeft) {
@@ -310,8 +316,9 @@ LatSetId Merger::combiSet(Kind kind, LatSetId s0, LatSetId s1, Operation *orig,
return s;
}
-LatSetId Merger::mapSet(Kind kind, LatSetId s0, Value v, Operation *op) {
- assert(kAbsF <= kind && kind <= kSelect);
+LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
+ Operation *op) {
+ assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect);
const LatSetId s = addSet();
for (const LatPointId p : latSets[s0]) {
const ExprId e = addExp(kind, latPoints[p].exp, v, op);
@@ -414,7 +421,7 @@ bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
}
bool Merger::expContainsTensor(ExprId e, TensorId t) const {
- if (tensorExps[e].kind == kTensor)
+ if (tensorExps[e].kind == TensorExp::Kind::kTensor)
return tensorExps[e].tensor == t;
switch (getExpArity(tensorExps[e].kind)) {
@@ -439,13 +446,13 @@ bool Merger::expContainsTensor(ExprId e, TensorId t) const {
bool Merger::hasNegateOnOut(ExprId e) const {
switch (tensorExps[e].kind) {
- case kNegF:
- case kNegC:
- case kNegI:
+ case TensorExp::Kind::kNegF:
+ case TensorExp::Kind::kNegC:
+ case TensorExp::Kind::kNegI:
return expContainsTensor(tensorExps[e].children.e0, outTensor);
- case kSubF:
- case kSubC:
- case kSubI:
+ case TensorExp::Kind::kSubF:
+ case TensorExp::Kind::kSubC:
+ case TensorExp::Kind::kSubI:
return expContainsTensor(tensorExps[e].children.e1, outTensor) ||
hasNegateOnOut(tensorExps[e].children.e0);
default: {
@@ -467,82 +474,82 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
assert(t < numTensors && e < tensorExps.size());
switch (tensorExps[e].kind) {
// Leaf.
- case kTensor:
+ case TensorExp::Kind::kTensor:
return tensorExps[e].tensor == t;
- case kInvariant:
- case kLoopVar:
+ case TensorExp::Kind::kInvariant:
+ case TensorExp::Kind::kLoopVar:
return false;
// Unary operations.
- case kAbsF:
- case kAbsC:
- case kAbsI:
- case kCeilF:
- case kFloorF:
- case kSqrtF:
- case kSqrtC:
- case kExpm1F:
- case kExpm1C:
- case kLog1pF:
- case kLog1pC:
- case kSinF:
- case kSinC:
- case kTanhF:
- case kTanhC:
- case kNegF:
- case kNegC:
- case kNegI:
- case kTruncF:
- case kExtF:
- case kCastFS:
- case kCastFU:
- case kCastSF:
- case kCastUF:
- case kCastS:
- case kCastU:
- case kCastIdx:
- case kTruncI:
- case kCIm:
- case kCRe:
- case kBitCast:
+ case TensorExp::Kind::kAbsF:
+ case TensorExp::Kind::kAbsC:
+ case TensorExp::Kind::kAbsI:
+ case TensorExp::Kind::kCeilF:
+ case TensorExp::Kind::kFloorF:
+ case TensorExp::Kind::kSqrtF:
+ case TensorExp::Kind::kSqrtC:
+ case TensorExp::Kind::kExpm1F:
+ case TensorExp::Kind::kExpm1C:
+ case TensorExp::Kind::kLog1pF:
+ case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kSinF:
+ case TensorExp::Kind::kSinC:
+ case TensorExp::Kind::kTanhF:
+ case TensorExp::Kind::kTanhC:
+ case TensorExp::Kind::kNegF:
+ case TensorExp::Kind::kNegC:
+ case TensorExp::Kind::kNegI:
+ case TensorExp::Kind::kTruncF:
+ case TensorExp::Kind::kExtF:
+ case TensorExp::Kind::kCastFS:
+ case TensorExp::Kind::kCastFU:
+ case TensorExp::Kind::kCastSF:
+ case TensorExp::Kind::kCastUF:
+ case TensorExp::Kind::kCastS:
+ case TensorExp::Kind::kCastU:
+ case TensorExp::Kind::kCastIdx:
+ case TensorExp::Kind::kTruncI:
+ case TensorExp::Kind::kCIm:
+ case TensorExp::Kind::kCRe:
+ case TensorExp::Kind::kBitCast:
return isSingleCondition(t, tensorExps[e].children.e0);
- case kBinaryBranch:
- case kUnary:
- case kSelect:
+ case TensorExp::Kind::kBinaryBranch:
+ case TensorExp::Kind::kUnary:
+ case TensorExp::Kind::kSelect:
return false;
// Binary operations.
- case kDivF: // note: x / c only
- case kDivC:
- case kDivS:
- case kDivU:
+ case TensorExp::Kind::kDivF: // note: x / c only
+ case TensorExp::Kind::kDivC:
+ case TensorExp::Kind::kDivS:
+ case TensorExp::Kind::kDivU:
assert(!maybeZero(tensorExps[e].children.e1));
return isSingleCondition(t, tensorExps[e].children.e0);
- case kShrS: // note: x >> inv only
- case kShrU:
- case kShlI:
+ case TensorExp::Kind::kShrS: // note: x >> inv only
+ case TensorExp::Kind::kShrU:
+ case TensorExp::Kind::kShlI:
assert(isInvariant(tensorExps[e].children.e1));
return isSingleCondition(t, tensorExps[e].children.e0);
- case kMulF:
- case kMulC:
- case kMulI:
- case kAndI:
+ case TensorExp::Kind::kMulF:
+ case TensorExp::Kind::kMulC:
+ case TensorExp::Kind::kMulI:
+ case TensorExp::Kind::kAndI:
if (isSingleCondition(t, tensorExps[e].children.e0))
return isSingleCondition(t, tensorExps[e].children.e1) ||
isInvariant(tensorExps[e].children.e1);
if (isSingleCondition(t, tensorExps[e].children.e1))
return isInvariant(tensorExps[e].children.e0);
return false;
- case kAddF:
- case kAddC:
- case kAddI:
+ case TensorExp::Kind::kAddF:
+ case TensorExp::Kind::kAddC:
+ case TensorExp::Kind::kAddI:
return isSingleCondition(t, tensorExps[e].children.e0) &&
isSingleCondition(t, tensorExps[e].children.e1);
- case kSubF:
- case kSubC:
- case kSubI:
- case kOrI:
- case kXorI:
- case kBinary:
- case kReduce:
+ case TensorExp::Kind::kSubF:
+ case TensorExp::Kind::kSubC:
+ case TensorExp::Kind::kSubI:
+ case TensorExp::Kind::kOrI:
+ case TensorExp::Kind::kXorI:
+ case TensorExp::Kind::kBinary:
+ case TensorExp::Kind::kReduce:
return false;
}
llvm_unreachable("unexpected kind");
@@ -572,98 +579,98 @@ bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
// Print methods (for debugging).
//===----------------------------------------------------------------------===//
-static const char *kindToOpSymbol(Kind kind) {
+static const char *kindToOpSymbol(TensorExp::Kind kind) {
switch (kind) {
// Leaf.
- case kTensor:
+ case TensorExp::Kind::kTensor:
return "tensor";
- case kInvariant:
+ case TensorExp::Kind::kInvariant:
return "invariant";
- case kLoopVar:
+ case TensorExp::Kind::kLoopVar:
return "index";
// Unary operations.
- case kAbsF:
- case kAbsC:
- case kAbsI:
+ case TensorExp::Kind::kAbsF:
+ case TensorExp::Kind::kAbsC:
+ case TensorExp::Kind::kAbsI:
return "abs";
- case kCeilF:
+ case TensorExp::Kind::kCeilF:
return "ceil";
- case kFloorF:
+ case TensorExp::Kind::kFloorF:
return "floor";
- case kSqrtF:
- case kSqrtC:
+ case TensorExp::Kind::kSqrtF:
+ case TensorExp::Kind::kSqrtC:
return "sqrt";
- case kExpm1F:
- case kExpm1C:
+ case TensorExp::Kind::kExpm1F:
+ case TensorExp::Kind::kExpm1C:
return "expm1";
- case kLog1pF:
- case kLog1pC:
+ case TensorExp::Kind::kLog1pF:
+ case TensorExp::Kind::kLog1pC:
return "log1p";
- case kSinF:
- case kSinC:
+ case TensorExp::Kind::kSinF:
+ case TensorExp::Kind::kSinC:
return "sin";
- case kTanhF:
- case kTanhC:
+ case TensorExp::Kind::kTanhF:
+ case TensorExp::Kind::kTanhC:
return "tanh";
- case kNegF:
- case kNegC:
- case kNegI:
+ case TensorExp::Kind::kNegF:
+ case TensorExp::Kind::kNegC:
+ case TensorExp::Kind::kNegI:
return "-";
- case kTruncF:
- case kExtF:
- case kCastFS:
- case kCastFU:
- case kCastSF:
- case kCastUF:
- case kCastS:
- case kCastU:
- case kCastIdx:
- case kTruncI:
- case kCIm:
+ case TensorExp::Kind::kTruncF:
+ case TensorExp::Kind::kExtF:
+ case TensorExp::Kind::kCastFS:
+ case TensorExp::Kind::kCastFU:
+ case TensorExp::Kind::kCastSF:
+ case TensorExp::Kind::kCastUF:
+ case TensorExp::Kind::kCastS:
+ case TensorExp::Kind::kCastU:
+ case TensorExp::Kind::kCastIdx:
+ case TensorExp::Kind::kTruncI:
+ case TensorExp::Kind::kCIm:
return "complex.im";
- case kCRe:
+ case TensorExp::Kind::kCRe:
return "complex.re";
- case kBitCast:
+ case TensorExp::Kind::kBitCast:
return "cast";
- case kBinaryBranch:
+ case TensorExp::Kind::kBinaryBranch:
return "binary_branch";
- case kUnary:
+ case TensorExp::Kind::kUnary:
return "unary";
- case kSelect:
+ case TensorExp::Kind::kSelect:
return "select";
// Binary operations.
- case kMulF:
- case kMulC:
- case kMulI:
+ case TensorExp::Kind::kMulF:
+ case TensorExp::Kind::kMulC:
+ case TensorExp::Kind::kMulI:
return "*";
- case kDivF:
- case kDivC:
- case kDivS:
- case kDivU:
+ case TensorExp::Kind::kDivF:
+ case TensorExp::Kind::kDivC:
+ case TensorExp::Kind::kDivS:
+ case TensorExp::Kind::kDivU:
return "/";
- case kAddF:
- case kAddC:
- case kAddI:
+ case TensorExp::Kind::kAddF:
+ case TensorExp::Kind::kAddC:
+ case TensorExp::Kind::kAddI:
return "+";
- case kSubF:
- case kSubC:
- case kSubI:
+ case TensorExp::Kind::kSubF:
+ case TensorExp::Kind::kSubC:
+ case TensorExp::Kind::kSubI:
return "-";
- case kAndI:
+ case TensorExp::Kind::kAndI:
return "&";
- case kOrI:
+ case TensorExp::Kind::kOrI:
return "|";
- case kXorI:
+ case TensorExp::Kind::kXorI:
return "^";
- case kShrS:
+ case TensorExp::Kind::kShrS:
return "a>>";
- case kShrU:
+ case TensorExp::Kind::kShrU:
return ">>";
- case kShlI:
+ case TensorExp::Kind::kShlI:
return "<<";
- case kBinary:
+ case TensorExp::Kind::kBinary:
return "binary";
- case kReduce:
+ case TensorExp::Kind::kReduce:
return "reduce";
}
llvm_unreachable("unexpected kind for symbol");
@@ -672,79 +679,79 @@ static const char *kindToOpSymbol(Kind kind) {
void Merger::dumpExp(ExprId e) const {
switch (tensorExps[e].kind) {
// Leaf.
- case kTensor:
+ case TensorExp::Kind::kTensor:
if (tensorExps[e].tensor == syntheticTensor)
llvm::dbgs() << "synthetic_";
else if (tensorExps[e].tensor == outTensor)
llvm::dbgs() << "output_";
llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
break;
- case kInvariant:
+ case TensorExp::Kind::kInvariant:
llvm::dbgs() << "invariant";
break;
- case kLoopVar:
+ case TensorExp::Kind::kLoopVar:
llvm::dbgs() << "loopvar_" << tensorExps[e].loop;
break;
// Unary operations.
- case kAbsF:
- case kAbsC:
- case kAbsI:
- case kCeilF:
- case kFloorF:
- case kSqrtF:
- case kSqrtC:
- case kExpm1F:
- case kExpm1C:
- case kLog1pF:
- case kLog1pC:
- case kSinF:
- case kSinC:
- case kTanhF:
- case kTanhC:
- case kNegF:
- case kNegC:
- case kNegI:
- case kTruncF:
- case kExtF:
- case kCastFS:
- case kCastFU:
- case kCastSF:
- case kCastUF:
- case kCastS:
- case kCastU:
- case kCastIdx:
- case kTruncI:
- case kCIm:
- case kCRe:
- case kBitCast:
- case kBinaryBranch:
- case kUnary:
- case kSelect:
+ case TensorExp::Kind::kAbsF:
+ case TensorExp::Kind::kAbsC:
+ case TensorExp::Kind::kAbsI:
+ case TensorExp::Kind::kCeilF:
+ case TensorExp::Kind::kFloorF:
+ case TensorExp::Kind::kSqrtF:
+ case TensorExp::Kind::kSqrtC:
+ case TensorExp::Kind::kExpm1F:
+ case TensorExp::Kind::kExpm1C:
+ case TensorExp::Kind::kLog1pF:
+ case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kSinF:
+ case TensorExp::Kind::kSinC:
+ case TensorExp::Kind::kTanhF:
+ case TensorExp::Kind::kTanhC:
+ case TensorExp::Kind::kNegF:
+ case TensorExp::Kind::kNegC:
+ case TensorExp::Kind::kNegI:
+ case TensorExp::Kind::kTruncF:
+ case TensorExp::Kind::kExtF:
+ case TensorExp::Kind::kCastFS:
+ case TensorExp::Kind::kCastFU:
+ case TensorExp::Kind::kCastSF:
+ case TensorExp::Kind::kCastUF:
+ case TensorExp::Kind::kCastS:
+ case TensorExp::Kind::kCastU:
+ case TensorExp::Kind::kCastIdx:
+ case TensorExp::Kind::kTruncI:
+ case TensorExp::Kind::kCIm:
+ case TensorExp::Kind::kCRe:
+ case TensorExp::Kind::kBitCast:
+ case TensorExp::Kind::kBinaryBranch:
+ case TensorExp::Kind::kUnary:
+ case TensorExp::Kind::kSelect:
llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
dumpExp(tensorExps[e].children.e0);
break;
// Binary operations.
- case kMulF:
- case kMulC:
- case kMulI:
- case kDivF:
- case kDivC:
- case kDivS:
- case kDivU:
- case kAddF:
- case kAddC:
- case kAddI:
- case kSubF:
- case kSubC:
- case kSubI:
- case kAndI:
- case kOrI:
- case kXorI:
- case kShrS:
- case kShrU:
- case kShlI:
- case kBinary:
- case kReduce:
+ case TensorExp::Kind::kMulF:
+ case TensorExp::Kind::kMulC:
+ case TensorExp::Kind::kMulI:
+ case TensorExp::Kind::kDivF:
+ case TensorExp::Kind::kDivC:
+ case TensorExp::Kind::kDivS:
+ case TensorExp::Kind::kDivU:
+ case TensorExp::Kind::kAddF:
+ case TensorExp::Kind::kAddC:
+ case TensorExp::Kind::kAddI:
+ case TensorExp::Kind::kSubF:
+ case TensorExp::Kind::kSubC:
+ case TensorExp::Kind::kSubI:
+ case TensorExp::Kind::kAndI:
+ case TensorExp::Kind::kOrI:
+ case TensorExp::Kind::kXorI:
+ case TensorExp::Kind::kShrS:
+ case TensorExp::Kind::kShrU:
+ case TensorExp::Kind::kShlI:
+ case TensorExp::Kind::kBinary:
+ case TensorExp::Kind::kReduce:
llvm::dbgs() << "(";
dumpExp(tensorExps[e].children.e0);
llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
@@ -793,12 +800,12 @@ void Merger::dumpBits(const BitVector &bits) const {
//===----------------------------------------------------------------------===//
LatSetId Merger::buildLattices(ExprId e, LoopId i) {
- const Kind kind = tensorExps[e].kind;
+ const TensorExp::Kind kind = tensorExps[e].kind;
switch (kind) {
// Leaf.
- case kTensor:
- case kInvariant:
- case kLoopVar: {
+ case TensorExp::Kind::kTensor:
+ case TensorExp::Kind::kInvariant:
+ case TensorExp::Kind::kLoopVar: {
// Either the loop-var is really used in the tensor expression, or it is
// set to the undefined loop-var in that level. An invariant expression,
// a proper index value, and a truly dynamic sparse output tensor are set
@@ -806,7 +813,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// iteration space is not skipped as a result of their contents.
const LatSetId s = addSet();
TensorId t = syntheticTensor;
- if (kind == kTensor) {
+ if (kind == TensorExp::Kind::kTensor) {
t = tensorExps[e].tensor;
if (hasSparseOut && t == outTensor)
t = syntheticTensor;
@@ -815,37 +822,37 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
return s;
}
// Unary operations.
- case kAbsF:
- case kAbsC:
- case kAbsI:
- case kCeilF:
- case kFloorF:
- case kSqrtF:
- case kSqrtC:
- case kExpm1F:
- case kExpm1C:
- case kLog1pF:
- case kLog1pC:
- case kSinF:
- case kSinC:
- case kTanhF:
- case kTanhC:
- case kNegF:
- case kNegC:
- case kNegI:
- case kTruncF:
- case kExtF:
- case kCastFS:
- case kCastFU:
- case kCastSF:
- case kCastUF:
- case kCastS:
- case kCastU:
- case kCastIdx:
- case kTruncI:
- case kCIm:
- case kCRe:
- case kBitCast:
+ case TensorExp::Kind::kAbsF:
+ case TensorExp::Kind::kAbsC:
+ case TensorExp::Kind::kAbsI:
+ case TensorExp::Kind::kCeilF:
+ case TensorExp::Kind::kFloorF:
+ case TensorExp::Kind::kSqrtF:
+ case TensorExp::Kind::kSqrtC:
+ case TensorExp::Kind::kExpm1F:
+ case TensorExp::Kind::kExpm1C:
+ case TensorExp::Kind::kLog1pF:
+ case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kSinF:
+ case TensorExp::Kind::kSinC:
+ case TensorExp::Kind::kTanhF:
+ case TensorExp::Kind::kTanhC:
+ case TensorExp::Kind::kNegF:
+ case TensorExp::Kind::kNegC:
+ case TensorExp::Kind::kNegI:
+ case TensorExp::Kind::kTruncF:
+ case TensorExp::Kind::kExtF:
+ case TensorExp::Kind::kCastFS:
+ case TensorExp::Kind::kCastFU:
+ case TensorExp::Kind::kCastSF:
+ case TensorExp::Kind::kCastUF:
+ case TensorExp::Kind::kCastS:
+ case TensorExp::Kind::kCastU:
+ case TensorExp::Kind::kCastIdx:
+ case TensorExp::Kind::kTruncI:
+ case TensorExp::Kind::kCIm:
+ case TensorExp::Kind::kCRe:
+ case TensorExp::Kind::kBitCast:
// A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
// lattice set of the operand through the operator into a new set.
//
@@ -854,13 +861,13 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// | 0 |-y |
return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
tensorExps[e].val);
- case kBinaryBranch:
- case kSelect:
+ case TensorExp::Kind::kBinaryBranch:
+ case TensorExp::Kind::kSelect:
// The left or right half of a binary operation which has already
// been split into separate operations for each region.
return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
tensorExps[e].op);
- case kUnary:
+ case TensorExp::Kind::kUnary:
// A custom unary operation.
//
// op y| !y | y |
@@ -879,14 +886,14 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
Block &absentBlock = absentRegion.front();
YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
Value absentVal = absentYield.getResult();
- const ExprId rhs = addExp(kInvariant, absentVal);
+ const ExprId rhs = addExp(TensorExp::Kind::kInvariant, absentVal);
return disjSet(kind, child0, buildLattices(rhs, i), unop);
}
// Binary operations.
- case kMulF:
- case kMulC:
- case kMulI:
- case kAndI:
+ case TensorExp::Kind::kMulF:
+ case TensorExp::Kind::kMulC:
+ case TensorExp::Kind::kMulI:
+ case TensorExp::Kind::kAndI:
// A multiplicative operation only needs to be performed
// for the conjunction of sparse iteration spaces.
//
@@ -898,10 +905,10 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
buildLattices(tensorExps[e].children.e1, i));
- case kDivF:
- case kDivC:
- case kDivS:
- case kDivU:
+ case TensorExp::Kind::kDivF:
+ case TensorExp::Kind::kDivC:
+ case TensorExp::Kind::kDivS:
+ case TensorExp::Kind::kDivU:
// A division is tricky, since 0/0, 0/c, c/0 all have
// specific outcomes for floating-point and integers.
// Thus, we need to traverse the full iteration space.
@@ -918,14 +925,14 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
assert(!maybeZero(tensorExps[e].children.e1));
return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
buildLattices(tensorExps[e].children.e1, i));
- case kAddF:
- case kAddC:
- case kAddI:
- case kSubF:
- case kSubC:
- case kSubI:
- case kOrI:
- case kXorI:
+ case TensorExp::Kind::kAddF:
+ case TensorExp::Kind::kAddC:
+ case TensorExp::Kind::kAddI:
+ case TensorExp::Kind::kSubF:
+ case TensorExp::Kind::kSubC:
+ case TensorExp::Kind::kSubI:
+ case TensorExp::Kind::kOrI:
+ case TensorExp::Kind::kXorI:
// An additive operation needs to be performed
// for the disjunction of sparse iteration spaces.
//
@@ -935,16 +942,16 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
// x | x |x+y| x | x |x-y|
return disjSet(kind, buildLattices(tensorExps[e].children.e0, i),
buildLattices(tensorExps[e].children.e1, i));
- case kShrS:
- case kShrU:
- case kShlI:
+ case TensorExp::Kind::kShrS:
+ case TensorExp::Kind::kShrU:
+ case TensorExp::Kind::kShlI:
// A shift operation by an invariant amount (viz. tensor expressions
// can only occur at the left-hand-side of the operator) can be handled
// with the conjunction rule.
assert(isInvariant(tensorExps[e].children.e1));
return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
buildLattices(tensorExps[e].children.e1, i));
- case kBinary:
+ case TensorExp::Kind::kBinary:
// A custom binary operation.
//
// x op y| !y | y |
@@ -971,11 +978,11 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
}
bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
- return combiSet(kBinary, child0, child1, binop, includeLeft,
- kBinaryBranch, leftYield, includeRight, kBinaryBranch,
- rightYield);
+ return combiSet(TensorExp::Kind::kBinary, child0, child1, binop,
+ includeLeft, TensorExp::Kind::kBinaryBranch, leftYield,
+ includeRight, TensorExp::Kind::kBinaryBranch, rightYield);
}
- case kReduce:
+ case TensorExp::Kind::kReduce:
// A custom reduce operation.
return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
buildLattices(tensorExps[e].children.e1, i),
@@ -993,7 +1000,7 @@ std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(ExprId e) const {
- if (tensorExps[e].kind == kInvariant) {
+ if (tensorExps[e].kind == TensorExp::Kind::kInvariant) {
if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
ArrayAttr arrayAttr = c.getValue();
return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
@@ -1008,7 +1015,7 @@ bool Merger::maybeZero(ExprId e) const {
}
bool Merger::isInvariant(ExprId e) const {
- return tensorExps[e].kind == kInvariant;
+ return tensorExps[e].kind == TensorExp::Kind::kInvariant;
}
Type Merger::inferType(ExprId e, Value src) const {
@@ -1060,21 +1067,21 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
if (arg.getOwner()->getParentOp() == op) {
OpOperand &t = op->getOpOperand(argN);
if (!op.isScalar(&t))
- return addExp(kTensor, argN);
+ return addExp(TensorExp::Kind::kTensor, argN);
v = t.get(); // get scalar value
}
// Any other argument (marked as scalar argument for the generic op
// or belonging to an enveloping op) is considered invariant.
- return addExp(kInvariant, v);
+ return addExp(TensorExp::Kind::kInvariant, v);
}
// Something defined outside is invariant.
Operation *def = v.getDefiningOp();
if (def->getBlock() != &op.getRegion().front())
- return addExp(kInvariant, v);
+ return addExp(TensorExp::Kind::kInvariant, v);
// Construct index operations.
if (def->getNumOperands() == 0) {
if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
- return addExp(kLoopVar, indexOp.getDim());
+ return addExp(TensorExp::Kind::kLoopVar, indexOp.getDim());
}
// Construct unary operations if subexpression can be built.
if (def->getNumOperands() == 1) {
@@ -1082,73 +1089,73 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
if (x.has_value()) {
const ExprId e = *x;
if (isa<math::AbsFOp>(def))
- return addExp(kAbsF, e);
+ return addExp(TensorExp::Kind::kAbsF, e);
if (isa<complex::AbsOp>(def))
- return addExp(kAbsC, e);
+ return addExp(TensorExp::Kind::kAbsC, e);
if (isa<math::AbsIOp>(def))
- return addExp(kAbsI, e);
+ return addExp(TensorExp::Kind::kAbsI, e);
if (isa<math::CeilOp>(def))
- return addExp(kCeilF, e);
+ return addExp(TensorExp::Kind::kCeilF, e);
if (isa<math::FloorOp>(def))
- return addExp(kFloorF, e);
+ return addExp(TensorExp::Kind::kFloorF, e);
if (isa<math::SqrtOp>(def))
- return addExp(kSqrtF, e);
+ return addExp(TensorExp::Kind::kSqrtF, e);
if (isa<complex::SqrtOp>(def))
- return addExp(kSqrtC, e);
+ return addExp(TensorExp::Kind::kSqrtC, e);
if (isa<math::ExpM1Op>(def))
- return addExp(kExpm1F, e);
+ return addExp(TensorExp::Kind::kExpm1F, e);
if (isa<complex::Expm1Op>(def))
- return addExp(kExpm1C, e);
+ return addExp(TensorExp::Kind::kExpm1C, e);
if (isa<math::Log1pOp>(def))
- return addExp(kLog1pF, e);
+ return addExp(TensorExp::Kind::kLog1pF, e);
if (isa<complex::Log1pOp>(def))
- return addExp(kLog1pC, e);
+ return addExp(TensorExp::Kind::kLog1pC, e);
if (isa<math::SinOp>(def))
- return addExp(kSinF, e);
+ return addExp(TensorExp::Kind::kSinF, e);
if (isa<complex::SinOp>(def))
- return addExp(kSinC, e);
+ return addExp(TensorExp::Kind::kSinC, e);
if (isa<math::TanhOp>(def))
- return addExp(kTanhF, e);
+ return addExp(TensorExp::Kind::kTanhF, e);
if (isa<complex::TanhOp>(def))
- return addExp(kTanhC, e);
+ return addExp(TensorExp::Kind::kTanhC, e);
if (isa<arith::NegFOp>(def))
- return addExp(kNegF, e); // no negi in std
+ return addExp(TensorExp::Kind::kNegF, e); // no negi in std
if (isa<complex::NegOp>(def))
- return addExp(kNegC, e);
+ return addExp(TensorExp::Kind::kNegC, e);
if (isa<arith::TruncFOp>(def))
- return addExp(kTruncF, e, v);
+ return addExp(TensorExp::Kind::kTruncF, e, v);
if (isa<arith::ExtFOp>(def))
- return addExp(kExtF, e, v);
+ return addExp(TensorExp::Kind::kExtF, e, v);
if (isa<arith::FPToSIOp>(def))
- return addExp(kCastFS, e, v);
+ return addExp(TensorExp::Kind::kCastFS, e, v);
if (isa<arith::FPToUIOp>(def))
- return addExp(kCastFU, e, v);
+ return addExp(TensorExp::Kind::kCastFU, e, v);
if (isa<arith::SIToFPOp>(def))
- return addExp(kCastSF, e, v);
+ return addExp(TensorExp::Kind::kCastSF, e, v);
if (isa<arith::UIToFPOp>(def))
- return addExp(kCastUF, e, v);
+ return addExp(TensorExp::Kind::kCastUF, e, v);
if (isa<arith::ExtSIOp>(def))
- return addExp(kCastS, e, v);
+ return addExp(TensorExp::Kind::kCastS, e, v);
if (isa<arith::ExtUIOp>(def))
- return addExp(kCastU, e, v);
+ return addExp(TensorExp::Kind::kCastU, e, v);
if (isa<arith::IndexCastOp>(def))
- return addExp(kCastIdx, e, v);
+ return addExp(TensorExp::Kind::kCastIdx, e, v);
if (isa<arith::TruncIOp>(def))
- return addExp(kTruncI, e, v);
+ return addExp(TensorExp::Kind::kTruncI, e, v);
if (isa<complex::ImOp>(def))
- return addExp(kCIm, e);
+ return addExp(TensorExp::Kind::kCIm, e);
if (isa<complex::ReOp>(def))
- return addExp(kCRe, e);
+ return addExp(TensorExp::Kind::kCRe, e);
if (isa<arith::BitcastOp>(def))
- return addExp(kBitCast, e, v);
+ return addExp(TensorExp::Kind::kBitCast, e, v);
if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
isAdmissibleBranch(unop, unop.getAbsentRegion()))
- return addExp(kUnary, e, Value(), def);
+ return addExp(TensorExp::Kind::kUnary, e, Value(), def);
}
if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
if (isAdmissibleBranch(selop, selop.getRegion()))
- return addExp(kSelect, e, Value(), def);
+ return addExp(TensorExp::Kind::kSelect, e, Value(), def);
}
}
}
@@ -1162,50 +1169,50 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
const ExprId e0 = *x;
const ExprId e1 = *y;
if (isa<arith::MulFOp>(def))
- return addExp(kMulF, e0, e1);
+ return addExp(TensorExp::Kind::kMulF, e0, e1);
if (isa<complex::MulOp>(def))
- return addExp(kMulC, e0, e1);
+ return addExp(TensorExp::Kind::kMulC, e0, e1);
if (isa<arith::MulIOp>(def))
- return addExp(kMulI, e0, e1);
+ return addExp(TensorExp::Kind::kMulI, e0, e1);
if (isa<arith::DivFOp>(def) && !maybeZero(e1))
- return addExp(kDivF, e0, e1);
+ return addExp(TensorExp::Kind::kDivF, e0, e1);
if (isa<complex::DivOp>(def) && !maybeZero(e1))
- return addExp(kDivC, e0, e1);
+ return addExp(TensorExp::Kind::kDivC, e0, e1);
if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
- return addExp(kDivS, e0, e1);
+ return addExp(TensorExp::Kind::kDivS, e0, e1);
if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
- return addExp(kDivU, e0, e1);
+ return addExp(TensorExp::Kind::kDivU, e0, e1);
if (isa<arith::AddFOp>(def))
- return addExp(kAddF, e0, e1);
+ return addExp(TensorExp::Kind::kAddF, e0, e1);
if (isa<complex::AddOp>(def))
- return addExp(kAddC, e0, e1);
+ return addExp(TensorExp::Kind::kAddC, e0, e1);
if (isa<arith::AddIOp>(def))
- return addExp(kAddI, e0, e1);
+ return addExp(TensorExp::Kind::kAddI, e0, e1);
if (isa<arith::SubFOp>(def))
- return addExp(kSubF, e0, e1);
+ return addExp(TensorExp::Kind::kSubF, e0, e1);
if (isa<complex::SubOp>(def))
- return addExp(kSubC, e0, e1);
+ return addExp(TensorExp::Kind::kSubC, e0, e1);
if (isa<arith::SubIOp>(def))
- return addExp(kSubI, e0, e1);
+ return addExp(TensorExp::Kind::kSubI, e0, e1);
if (isa<arith::AndIOp>(def))
- return addExp(kAndI, e0, e1);
+ return addExp(TensorExp::Kind::kAndI, e0, e1);
if (isa<arith::OrIOp>(def))
- return addExp(kOrI, e0, e1);
+ return addExp(TensorExp::Kind::kOrI, e0, e1);
if (isa<arith::XOrIOp>(def))
- return addExp(kXorI, e0, e1);
+ return addExp(TensorExp::Kind::kXorI, e0, e1);
if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
- return addExp(kShrS, e0, e1);
+ return addExp(TensorExp::Kind::kShrS, e0, e1);
if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
- return addExp(kShrU, e0, e1);
+ return addExp(TensorExp::Kind::kShrU, e0, e1);
if (isa<arith::ShLIOp>(def) && isInvariant(e1))
- return addExp(kShlI, e0, e1);
+ return addExp(TensorExp::Kind::kShlI, e0, e1);
if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
(binop.getLeftIdentity() ||
isAdmissibleBranch(binop, binop.getLeftRegion())) &&
(binop.getRightIdentity() ||
isAdmissibleBranch(binop, binop.getRightRegion())))
- return addExp(kBinary, e0, e1, Value(), def);
+ return addExp(TensorExp::Kind::kBinary, e0, e1, Value(), def);
}
}
}
@@ -1219,7 +1226,7 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
const ExprId e1 = *y;
if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
if (isAdmissibleBranch(redop, redop.getRegion()))
- return addExp(kReduce, e0, e1, Value(), def);
+ return addExp(TensorExp::Kind::kReduce, e0, e1, Value(), def);
}
}
}
@@ -1276,136 +1283,136 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
Value v1) const {
switch (tensorExps[e].kind) {
// Leaf.
- case kTensor:
- case kInvariant:
- case kLoopVar:
+ case TensorExp::Kind::kTensor:
+ case TensorExp::Kind::kInvariant:
+ case TensorExp::Kind::kLoopVar:
llvm_unreachable("unexpected non-op");
// Unary operations.
- case kAbsF:
+ case TensorExp::Kind::kAbsF:
return rewriter.create<math::AbsFOp>(loc, v0);
- case kAbsC: {
+ case TensorExp::Kind::kAbsC: {
auto type = v0.getType().cast<ComplexType>();
auto eltType = type.getElementType().cast<FloatType>();
return rewriter.create<complex::AbsOp>(loc, eltType, v0);
}
- case kAbsI:
+ case TensorExp::Kind::kAbsI:
return rewriter.create<math::AbsIOp>(loc, v0);
- case kCeilF:
+ case TensorExp::Kind::kCeilF:
return rewriter.create<math::CeilOp>(loc, v0);
- case kFloorF:
+ case TensorExp::Kind::kFloorF:
return rewriter.create<math::FloorOp>(loc, v0);
- case kSqrtF:
+ case TensorExp::Kind::kSqrtF:
return rewriter.create<math::SqrtOp>(loc, v0);
- case kSqrtC:
+ case TensorExp::Kind::kSqrtC:
return rewriter.create<complex::SqrtOp>(loc, v0);
- case kExpm1F:
+ case TensorExp::Kind::kExpm1F:
return rewriter.create<math::ExpM1Op>(loc, v0);
- case kExpm1C:
+ case TensorExp::Kind::kExpm1C:
return rewriter.create<complex::Expm1Op>(loc, v0);
- case kLog1pF:
+ case TensorExp::Kind::kLog1pF:
return rewriter.create<math::Log1pOp>(loc, v0);
- case kLog1pC:
+ case TensorExp::Kind::kLog1pC:
return rewriter.create<complex::Log1pOp>(loc, v0);
- case kSinF:
+ case TensorExp::Kind::kSinF:
return rewriter.create<math::SinOp>(loc, v0);
- case kSinC:
+ case TensorExp::Kind::kSinC:
return rewriter.create<complex::SinOp>(loc, v0);
- case kTanhF:
+ case TensorExp::Kind::kTanhF:
return rewriter.create<math::TanhOp>(loc, v0);
- case kTanhC:
+ case TensorExp::Kind::kTanhC:
return rewriter.create<complex::TanhOp>(loc, v0);
- case kNegF:
+ case TensorExp::Kind::kNegF:
return rewriter.create<arith::NegFOp>(loc, v0);
- case kNegC:
+ case TensorExp::Kind::kNegC:
return rewriter.create<complex::NegOp>(loc, v0);
- case kNegI: // no negi in std
+ case TensorExp::Kind::kNegI: // no negi in std
return rewriter.create<arith::SubIOp>(
loc,
rewriter.create<arith::ConstantOp>(loc, v0.getType(),
rewriter.getZeroAttr(v0.getType())),
v0);
- case kTruncF:
+ case TensorExp::Kind::kTruncF:
return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
- case kExtF:
+ case TensorExp::Kind::kExtF:
return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
- case kCastFS:
+ case TensorExp::Kind::kCastFS:
return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
- case kCastFU:
+ case TensorExp::Kind::kCastFU:
return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
- case kCastSF:
+ case TensorExp::Kind::kCastSF:
return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
- case kCastUF:
+ case TensorExp::Kind::kCastUF:
return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
- case kCastS:
+ case TensorExp::Kind::kCastS:
return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
- case kCastU:
+ case TensorExp::Kind::kCastU:
return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
- case kCastIdx:
+ case TensorExp::Kind::kCastIdx:
return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
- case kTruncI:
+ case TensorExp::Kind::kTruncI:
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
- case kCIm: {
+ case TensorExp::Kind::kCIm: {
auto type = v0.getType().cast<ComplexType>();
auto eltType = type.getElementType().cast<FloatType>();
return rewriter.create<complex::ImOp>(loc, eltType, v0);
}
- case kCRe: {
+ case TensorExp::Kind::kCRe: {
auto type = v0.getType().cast<ComplexType>();
auto eltType = type.getElementType().cast<FloatType>();
return rewriter.create<complex::ReOp>(loc, eltType, v0);
}
- case kBitCast:
+ case TensorExp::Kind::kBitCast:
return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
// Binary operations.
- case kMulF:
+ case TensorExp::Kind::kMulF:
return rewriter.create<arith::MulFOp>(loc, v0, v1);
- case kMulC:
+ case TensorExp::Kind::kMulC:
return rewriter.create<complex::MulOp>(loc, v0, v1);
- case kMulI:
+ case TensorExp::Kind::kMulI:
return rewriter.create<arith::MulIOp>(loc, v0, v1);
- case kDivF:
+ case TensorExp::Kind::kDivF:
return rewriter.create<arith::DivFOp>(loc, v0, v1);
- case kDivC:
+ case TensorExp::Kind::kDivC:
return rewriter.create<complex::DivOp>(loc, v0, v1);
- case kDivS:
+ case TensorExp::Kind::kDivS:
return rewriter.create<arith::DivSIOp>(loc, v0, v1);
- case kDivU:
+ case TensorExp::Kind::kDivU:
return rewriter.create<arith::DivUIOp>(loc, v0, v1);
- case kAddF:
+ case TensorExp::Kind::kAddF:
return rewriter.create<arith::AddFOp>(loc, v0, v1);
- case kAddC:
+ case TensorExp::Kind::kAddC:
return rewriter.create<complex::AddOp>(loc, v0, v1);
- case kAddI:
+ case TensorExp::Kind::kAddI:
return rewriter.create<arith::AddIOp>(loc, v0, v1);
- case kSubF:
+ case TensorExp::Kind::kSubF:
return rewriter.create<arith::SubFOp>(loc, v0, v1);
- case kSubC:
+ case TensorExp::Kind::kSubC:
return rewriter.create<complex::SubOp>(loc, v0, v1);
- case kSubI:
+ case TensorExp::Kind::kSubI:
return rewriter.create<arith::SubIOp>(loc, v0, v1);
- case kAndI:
+ case TensorExp::Kind::kAndI:
return rewriter.create<arith::AndIOp>(loc, v0, v1);
- case kOrI:
+ case TensorExp::Kind::kOrI:
return rewriter.create<arith::OrIOp>(loc, v0, v1);
- case kXorI:
+ case TensorExp::Kind::kXorI:
return rewriter.create<arith::XOrIOp>(loc, v0, v1);
- case kShrS:
+ case TensorExp::Kind::kShrS:
return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
- case kShrU:
+ case TensorExp::Kind::kShrU:
return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
- case kShlI:
+ case TensorExp::Kind::kShlI:
return rewriter.create<arith::ShLIOp>(loc, v0, v1);
- case kBinaryBranch: // semi-ring ops with custom logic.
+ case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
return insertYieldOp(rewriter, loc,
*tensorExps[e].op->getBlock()->getParent(), {v0});
- case kUnary:
+ case TensorExp::Kind::kUnary:
return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
- case kSelect:
+ case TensorExp::Kind::kSelect:
return insertYieldOp(rewriter, loc,
cast<SelectOp>(tensorExps[e].op).getRegion(), {v0});
- case kBinary:
+ case TensorExp::Kind::kBinary:
return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
- case kReduce: {
+ case TensorExp::Kind::kReduce: {
ReduceOp redOp = cast<ReduceOp>(tensorExps[e].op);
return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
}
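
The rewrite above is purely mechanical: the former file-scope enumerators now require the `TensorExp::Kind::` qualifier, because a scoped enum's enumerators no longer leak into the enclosing namespace. A minimal, self-contained sketch of the pattern being applied (the kinds and the `describe` helper below are illustrative placeholders, not the real Merger API):

struct TensorExp {
  // Nested scoped enum: enumerators are reachable only through
  // TensorExp::Kind::..., which is the point of this change.
  enum class Kind { kTensor, kMulF, kAddF };
  Kind kind;
};

// Every switch gains the full qualifier; an unqualified `kMulF`
// no longer compiles.
const char *describe(TensorExp::Kind k) {
  switch (k) {
  case TensorExp::Kind::kTensor:
    return "leaf tensor";
  case TensorExp::Kind::kMulF:
    return "float multiplication";
  case TensorExp::Kind::kAddF:
    return "float addition";
  }
  return "unknown"; // unreachable when all kinds are covered
}
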
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 10d350f7c6b97..270b5836907e3 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -23,18 +23,18 @@ namespace {
///
#define FOREVERY_BINOP(DO) \
- DO(mulf, Kind::kMulF) \
- DO(mulc, Kind::kMulC) \
- DO(muli, Kind::kMulI) \
- DO(addf, Kind::kAddF) \
- DO(addc, Kind::kAddC) \
- DO(addi, Kind::kAddI) \
- DO(subf, Kind::kSubF) \
- DO(subc, Kind::kSubC) \
- DO(subi, Kind::kSubI) \
- DO(andi, Kind::kAndI) \
- DO(xori, Kind::kXorI) \
- DO(ori, Kind::kOrI)
+ DO(mulf, TensorExp::Kind::kMulF) \
+ DO(mulc, TensorExp::Kind::kMulC) \
+ DO(muli, TensorExp::Kind::kMulI) \
+ DO(addf, TensorExp::Kind::kAddF) \
+ DO(addc, TensorExp::Kind::kAddC) \
+ DO(addi, TensorExp::Kind::kAddI) \
+ DO(subf, TensorExp::Kind::kSubF) \
+ DO(subc, TensorExp::Kind::kSubC) \
+ DO(subi, TensorExp::Kind::kSubI) \
+ DO(andi, TensorExp::Kind::kAndI) \
+ DO(xori, TensorExp::Kind::kXorI) \
+ DO(ori, TensorExp::Kind::kOrI)
// TODO: Disjunctive binary operations that need special handling are not
// included; e.g., division is not tested (for now) as it needs a constant
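
FOREVERY_BINOP is an X-macro: each test helper gets stamped out once per binary kind by passing a generator macro as DO. A sketch of how such an expansion works under the new qualified names; the body of IMPL_BINOP_EXPR shown here is assumed for illustration (the diff below only shows its first line), and it presumes a `merger` member as in the test fixture:

// Hypothetical generator: defines one helper per (OP, KIND) pair,
// e.g. mulfExpr(e0, e1), addiExpr(e0, e1), each forwarding the
// qualified kind to merger.addExp.
#define IMPL_BINOP_EXPR(OP, KIND)                                            \
  unsigned OP##Expr(unsigned e0, unsigned e1) {                              \
    return merger.addExp(KIND, e0, e1);                                      \
  }
FOREVERY_BINOP(IMPL_BINOP_EXPR)
#undef IMPL_BINOP_EXPR
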
@@ -82,7 +82,7 @@ namespace {
/// Simple recursive data structure used to match expressions in Mergers.
struct Pattern {
- Kind kind;
+ TensorExp::Kind kind;
/// Expressions representing tensors simply have a tensor number.
unsigned tensorNum;
@@ -94,11 +94,12 @@ struct Pattern {
/// Constructors.
/// Rather than using these, please use the readable helper constructor
/// functions below to make tests more readable.
- Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
- Pattern(Kind kind, const std::shared_ptr<Pattern> &e0,
+ Pattern(unsigned tensorNum)
+ : kind(TensorExp::Kind::kTensor), tensorNum(tensorNum) {}
+ Pattern(TensorExp::Kind kind, const std::shared_ptr<Pattern> &e0,
const std::shared_ptr<Pattern> &e1)
: kind(kind), e0(e0), e1(e1) {
- assert(kind >= Kind::kMulF);
+ assert(kind >= TensorExp::Kind::kMulF);
assert(e0 && e1);
}
};
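
Note that the assert in the constructor keeps working after the migration: switching from a plain enum to an enum class removes the implicit conversion to int but preserves the built-in relational operators, so ordering-based checks like `kind >= TensorExp::Kind::kMulF` still compile unchanged. A quick standalone illustration (not from the patch):

enum class Kind { kTensor, kMulF, kAddF };

static_assert(Kind::kAddF >= Kind::kMulF,
              "relational operators survive the enum-class migration");

// int i = Kind::kMulF;                 // error: no implicit conversion
int i = static_cast<int>(Kind::kMulF);  // an explicit cast is now required
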
@@ -134,7 +135,7 @@ class MergerTestBase : public ::testing::Test {
///
unsigned tensor(unsigned tensor) {
- return merger.addExp(Kind::kTensor, tensor);
+ return merger.addExp(TensorExp::Kind::kTensor, tensor);
}
#define IMPL_BINOP_EXPR(OP, KIND) \
@@ -222,69 +223,69 @@ class MergerTestBase : public ::testing::Test {
return false;
switch (tensorExp.kind) {
// Leaf.
- case kTensor:
+ case TensorExp::Kind::kTensor:
return tensorExp.tensor == pattern->tensorNum;
- case kInvariant:
- case kLoopVar:
+ case TensorExp::Kind::kInvariant:
+ case TensorExp::Kind::kLoopVar:
llvm_unreachable("invariant not handled yet");
// Unary operations.
- case kAbsF:
- case kAbsC:
- case kAbsI:
- case kCeilF:
- case kFloorF:
- case kSqrtF:
- case kSqrtC:
- case kExpm1F:
- case kExpm1C:
- case kLog1pF:
- case kLog1pC:
- case kSinF:
- case kSinC:
- case kTanhF:
- case kTanhC:
- case kNegF:
- case kNegC:
- case kNegI:
- case kTruncF:
- case kExtF:
- case kCastFS:
- case kCastFU:
- case kCastSF:
- case kCastUF:
- case kCastS:
- case kCastU:
- case kCastIdx:
- case kTruncI:
- case kCIm:
- case kCRe:
- case kBitCast:
- case kSelect:
- case kBinaryBranch:
- case kUnary:
+ case TensorExp::Kind::kAbsF:
+ case TensorExp::Kind::kAbsC:
+ case TensorExp::Kind::kAbsI:
+ case TensorExp::Kind::kCeilF:
+ case TensorExp::Kind::kFloorF:
+ case TensorExp::Kind::kSqrtF:
+ case TensorExp::Kind::kSqrtC:
+ case TensorExp::Kind::kExpm1F:
+ case TensorExp::Kind::kExpm1C:
+ case TensorExp::Kind::kLog1pF:
+ case TensorExp::Kind::kLog1pC:
+ case TensorExp::Kind::kSinF:
+ case TensorExp::Kind::kSinC:
+ case TensorExp::Kind::kTanhF:
+ case TensorExp::Kind::kTanhC:
+ case TensorExp::Kind::kNegF:
+ case TensorExp::Kind::kNegC:
+ case TensorExp::Kind::kNegI:
+ case TensorExp::Kind::kTruncF:
+ case TensorExp::Kind::kExtF:
+ case TensorExp::Kind::kCastFS:
+ case TensorExp::Kind::kCastFU:
+ case TensorExp::Kind::kCastSF:
+ case TensorExp::Kind::kCastUF:
+ case TensorExp::Kind::kCastS:
+ case TensorExp::Kind::kCastU:
+ case TensorExp::Kind::kCastIdx:
+ case TensorExp::Kind::kTruncI:
+ case TensorExp::Kind::kCIm:
+ case TensorExp::Kind::kCRe:
+ case TensorExp::Kind::kBitCast:
+ case TensorExp::Kind::kSelect:
+ case TensorExp::Kind::kBinaryBranch:
+ case TensorExp::Kind::kUnary:
return compareExpression(tensorExp.children.e0, pattern->e0);
// Binary operations.
- case kMulF:
- case kMulC:
- case kMulI:
- case kDivF:
- case kDivC:
- case kDivS:
- case kDivU:
- case kAddF:
- case kAddC:
- case kAddI:
- case kSubF:
- case kSubC:
- case kSubI:
- case kAndI:
- case kOrI:
- case kXorI:
- case kShrS:
- case kShrU:
- case kShlI:
- case kBinary:
- case kReduce:
+ case TensorExp::Kind::kMulF:
+ case TensorExp::Kind::kMulC:
+ case TensorExp::Kind::kMulI:
+ case TensorExp::Kind::kDivF:
+ case TensorExp::Kind::kDivC:
+ case TensorExp::Kind::kDivS:
+ case TensorExp::Kind::kDivU:
+ case TensorExp::Kind::kAddF:
+ case TensorExp::Kind::kAddC:
+ case TensorExp::Kind::kAddI:
+ case TensorExp::Kind::kSubF:
+ case TensorExp::Kind::kSubC:
+ case TensorExp::Kind::kSubI:
+ case TensorExp::Kind::kAndI:
+ case TensorExp::Kind::kOrI:
+ case TensorExp::Kind::kXorI:
+ case TensorExp::Kind::kShrS:
+ case TensorExp::Kind::kShrU:
+ case TensorExp::Kind::kShlI:
+ case TensorExp::Kind::kBinary:
+ case TensorExp::Kind::kReduce:
return compareExpression(tensorExp.children.e0, pattern->e0) &&
compareExpression(tensorExp.children.e1, pattern->e1);
}
@@ -312,15 +313,15 @@ class MergerTest3T1L : public MergerTestBase {
EXPECT_TRUE(merger.getOutTensorID() == t2);
// Tensor 0: sparse input vector.
- merger.addExp(Kind::kTensor, t0, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
// Tensor 1: sparse input vector.
- merger.addExp(Kind::kTensor, t1, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed);
// Tensor 2: dense output vector.
- merger.addExp(Kind::kTensor, t2, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense);
}
};
@@ -337,19 +338,19 @@ class MergerTest4T1L : public MergerTestBase {
EXPECT_TRUE(merger.getOutTensorID() == t3);
// Tensor 0: sparse input vector.
- merger.addExp(Kind::kTensor, t0, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
// Tensor 1: sparse input vector.
- merger.addExp(Kind::kTensor, t1, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed);
// Tensor 2: sparse input vector
- merger.addExp(Kind::kTensor, t2, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed);
// Tensor 3: dense output vector
- merger.addExp(Kind::kTensor, t3, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t3, -1u);
merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense);
}
};
@@ -370,15 +371,15 @@ class MergerTest3T1LD : public MergerTestBase {
EXPECT_TRUE(merger.getOutTensorID() == t2);
// Tensor 0: sparse input vector.
- merger.addExp(Kind::kTensor, t0, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
// Tensor 1: dense input vector.
- merger.addExp(Kind::kTensor, t1, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense);
// Tensor 2: dense output vector.
- merger.addExp(Kind::kTensor, t2, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense);
}
};
@@ -399,19 +400,19 @@ class MergerTest4T1LU : public MergerTestBase {
EXPECT_TRUE(merger.getOutTensorID() == t3);
// Tensor 0: undef input vector.
- merger.addExp(Kind::kTensor, t0, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef);
// Tensor 1: dense input vector.
- merger.addExp(Kind::kTensor, t1, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense);
// Tensor 2: undef input vector.
- merger.addExp(Kind::kTensor, t2, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
merger.setLevelAndType(t2, l0, 0, DimLevelType::Undef);
// Tensor 3: dense output vector.
- merger.addExp(Kind::kTensor, t3, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t3, -1u);
merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense);
}
};
@@ -435,15 +436,15 @@ class MergerTest3T1LSo : public MergerTestBase {
merger.setHasSparseOut(true);
// Tensor 0: undef input vector.
- merger.addExp(Kind::kTensor, t0, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef);
// Tensor 1: undef input vector.
- merger.addExp(Kind::kTensor, t1, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
merger.setLevelAndType(t1, l0, 0, DimLevelType::Undef);
// Tensor 2: sparse output vector.
- merger.addExp(Kind::kTensor, t2, -1u);
+ merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed);
}
};