[Mlir-commits] [mlir] 1f58ae8 - [mlir][sparse] Making `TensorExp::Kind` a nested enum-class

wren romano llvmlistbot at llvm.org
Mon Mar 20 16:12:39 PDT 2023


Author: wren romano
Date: 2023-03-20T16:12:31-07:00
New Revision: 1f58ae80661b7c9738ca5cff08ff8246ddecf987

URL: https://github.com/llvm/llvm-project/commit/1f58ae80661b7c9738ca5cff08ff8246ddecf987
DIFF: https://github.com/llvm/llvm-project/commit/1f58ae80661b7c9738ca5cff08ff8246ddecf987.diff

LOG: [mlir][sparse] Making `TensorExp::Kind` a nested enum-class

This improves namespacing, and follows the pattern used for "Kind" enums elsewhere in MLIR.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D146086

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 4a83237fb1634..6e39404bb28aa 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -23,87 +23,6 @@
 namespace mlir {
 namespace sparse_tensor {
 
-/// Tensor expression kind.
-///
-/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
-/// That is, its argument is a `LoopId` identifying the loop-variable
-/// in question, and its value will be the current iteration's value
-/// of that loop-variable.  See the `LoopId` documentation for more details.
-//
-// TODO: make this an `enum class` nested in the `TensorExp` class;
-// to improve namespacing, and match the pattern used by other "Kind"
-// enums in MLIR.
-//
-// TODO: Modify this definition so that the numeric values already encode
-// the `ExpArity` (while extending the notion of "arity" to include not
-// just the number of `ExprId` children the node has, but also whether the
-// node has a `Value` and/or `Operation*`).  Doing this will avoid needing
-// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor,
-// and should help clean up a few other places as well.
-enum Kind {
-  // Leaf.
-  kTensor = 0,
-  kInvariant,
-  kLoopVar,
-  // Unary operations.
-  kAbsF,
-  kAbsC,
-  kAbsI,
-  kCeilF,
-  kFloorF,
-  kSqrtF,
-  kSqrtC,
-  kExpm1F,
-  kExpm1C,
-  kLog1pF,
-  kLog1pC,
-  kSinF,
-  kSinC,
-  kTanhF,
-  kTanhC,
-  kNegF,
-  kNegC,
-  kNegI,
-  kTruncF,
-  kExtF,
-  kCastFS, // signed
-  kCastFU, // unsigned
-  kCastSF, // signed
-  kCastUF, // unsigned
-  kCastS,  // signed
-  kCastU,  // unsigned
-  kCastIdx,
-  kTruncI,
-  kCIm, // complex.im
-  kCRe, // complex.re
-  kBitCast,
-  kBinaryBranch, // semiring unary branch created from a binary op
-  kUnary,        // semiring unary op
-  kSelect,       // custom selection criteria
-  // Binary operations.
-  kMulF,
-  kMulC,
-  kMulI,
-  kDivF,
-  kDivC, // complex
-  kDivS, // signed
-  kDivU, // unsigned
-  kAddF,
-  kAddC,
-  kAddI,
-  kSubF,
-  kSubC,
-  kSubI,
-  kAndI,
-  kOrI,
-  kXorI,
-  kShrS, // signed
-  kShrU, // unsigned
-  kShlI,
-  kBinary, // semiring binary op
-  kReduce, // semiring reduction op
-};
-
 // TODO: These type aliases currently only serve to make the code more
 // self-documenting, however because they are not type-checked they can
 // do nothing to prevent mixups.  We should really change them from mere
@@ -169,6 +88,8 @@ struct Children {
 
 /// Tensor expression. Represents a MLIR expression in tensor index notation.
 struct TensorExp {
+  enum class Kind;
+
   // The `x` parameter has different types depending on the value of the
   // `k` parameter.  The correspondences are:
   // * `kTensor`    -> `TensorId`
@@ -207,6 +128,83 @@ struct TensorExp {
   Operation *op;
 };
 
+/// Tensor expression kind.
+///
+/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
+/// That is, its argument is a `LoopId` identifying the loop-variable
+/// in question, and its value will be the current iteration's value
+/// of that loop-variable.  See the `LoopId` documentation for more details.
+//
+// TODO: Modify this definition so that the numeric values already encode
+// the `ExpArity` (while extending the notion of "arity" to include not
+// just the number of `ExprId` children the node has, but also whether the
+// node has a `Value` and/or `Operation*`).  Doing this will avoid needing
+// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor,
+// and should help clean up a few other places as well.
+enum class TensorExp::Kind {
+  // Leaf.
+  kTensor = 0,
+  kInvariant,
+  kLoopVar,
+  // Unary operations.
+  kAbsF,
+  kAbsC,
+  kAbsI,
+  kCeilF,
+  kFloorF,
+  kSqrtF,
+  kSqrtC,
+  kExpm1F,
+  kExpm1C,
+  kLog1pF,
+  kLog1pC,
+  kSinF,
+  kSinC,
+  kTanhF,
+  kTanhC,
+  kNegF,
+  kNegC,
+  kNegI,
+  kTruncF,
+  kExtF,
+  kCastFS, // signed
+  kCastFU, // unsigned
+  kCastSF, // signed
+  kCastUF, // unsigned
+  kCastS,  // signed
+  kCastU,  // unsigned
+  kCastIdx,
+  kTruncI,
+  kCIm, // complex.im
+  kCRe, // complex.re
+  kBitCast,
+  kBinaryBranch, // semiring unary branch created from a binary op
+  kUnary,        // semiring unary op
+  kSelect,       // custom selection criteria
+  // Binary operations.
+  kMulF,
+  kMulC,
+  kMulI,
+  kDivF,
+  kDivC, // complex
+  kDivS, // signed
+  kDivU, // unsigned
+  kAddF,
+  kAddC,
+  kAddI,
+  kSubF,
+  kSubC,
+  kSubI,
+  kAndI,
+  kOrI,
+  kXorI,
+  kShrS, // signed
+  kShrU, // unsigned
+  kShlI,
+  kBinary, // semiring binary op
+  kReduce, // semiring reduction op
+};
+
 /// Lattice point.  Each lattice point consists of a formal conjunction
 /// of `TensorLoopId`s, together with the identifier of the corresponding
 /// tensor expression.  The formal conjunction is represented as a set of
@@ -271,12 +269,12 @@ class Merger {
   /// Constructs a new tensor expression, and returns its identifier.
   /// The type of the `e0` argument varies according to the value of the
   /// `k` argument, as described by the `TensorExp` ctor.
-  ExprId addExp(Kind k, unsigned e0, ExprId e1 = kInvalidId, Value v = Value(),
-                Operation *op = nullptr);
-  ExprId addExp(Kind k, ExprId e, Value v, Operation *op = nullptr) {
+  ExprId addExp(TensorExp::Kind k, unsigned e0, ExprId e1 = kInvalidId,
+                Value v = Value(), Operation *op = nullptr);
+  ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr) {
     return addExp(k, e, kInvalidId, v, op);
   }
-  ExprId addExp(Kind k, Value v, Operation *op = nullptr) {
+  ExprId addExp(TensorExp::Kind k, Value v, Operation *op = nullptr) {
     return addExp(k, kInvalidId, kInvalidId, v, op);
   }
 
@@ -290,30 +288,31 @@ class Merger {
   /// of `LoopId` (effectively constructing a larger "intersection" of those
   /// loops) with a newly constructed tensor (sub)expression of given kind.
   /// Returns the identifier of the new lattice point.
-  LatPointId conjLat(Kind kind, LatPointId p0, LatPointId p1,
+  LatPointId conjLat(TensorExp::Kind kind, LatPointId p0, LatPointId p1,
                      Operation *op = nullptr);
 
   /// Conjunctive merge of two lattice sets: `(s0 /\_op s1)`.
   /// Returns the identifier of the new set.
-  LatSetId conjSet(Kind kind, LatSetId s0, LatSetId s1,
+  LatSetId conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
                    Operation *op = nullptr);
 
   /// Disjunctive merge of two lattice sets: `(s0 /\_op s1, s0, s1)`.
   /// Returns the identifier of the new set.
-  LatSetId disjSet(Kind kind, LatSetId s0, LatSetId s1,
+  LatSetId disjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
                    Operation *op = nullptr);
 
   /// Disjunctive merge of two lattice sets with custom handling of the
   /// overlap, left, and right regions.  Any region may be left missing
   /// in the output.  Returns the identifier of the new set.
-  LatSetId combiSet(Kind kind, LatSetId s0, LatSetId s1, Operation *orig,
-                    bool includeLeft, Kind ltrans, Operation *opleft,
-                    bool includeRight, Kind rtrans, Operation *opright);
+  LatSetId combiSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
+                    Operation *orig, bool includeLeft, TensorExp::Kind ltrans,
+                    Operation *opleft, bool includeRight,
+                    TensorExp::Kind rtrans, Operation *opright);
 
   /// Maps the unary operator over the lattice set of the operand, i.e. each
   /// lattice point on an expression E is simply copied over, but with OP E
   /// as new expression. Returns the identifier of the new set.
-  LatSetId mapSet(Kind kind, LatSetId s, Value v = Value(),
+  LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v = Value(),
                   Operation *op = nullptr);
 
   /// Optimizes the iteration lattice points in the given set. This
@@ -377,7 +376,8 @@ class Merger {
 
   /// Returns true if the expression is `(kTensor t)`.
   bool expIsTensor(ExprId e, TensorId t) const {
-    return tensorExps[e].kind == kTensor && tensorExps[e].tensor == t;
+    return tensorExps[e].kind == TensorExp::Kind::kTensor &&
+           tensorExps[e].tensor == t;
   }
 
   /// Returns true if the expression contains the tensor as an operand.

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index f189b14c60c7e..d8aeb44811534 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1045,8 +1045,9 @@ static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
     if (!rhs) {
       // Only unary and binary are allowed to return uninitialized rhs
       // to indicate missing output.
-      assert(env.exp(exp).kind == kUnary || env.exp(exp).kind == kBinary);
-    } else if (env.exp(exp).kind == kSelect) {
+      assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
+             env.exp(exp).kind == TensorExp::Kind::kBinary);
+    } else if (env.exp(exp).kind == TensorExp::Kind::kSelect) {
       // Select operation insertion.
       Value chain = env.getInsertionChain();
       scf::IfOp ifOp =
@@ -1114,28 +1115,29 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
     return Value();
   const TensorExp &exp = env.exp(e);
   const auto kind = exp.kind;
-  if (kind == Kind::kTensor)
+  if (kind == TensorExp::Kind::kTensor)
     return genTensorLoad(env, rewriter, e);
-  if (kind == Kind::kInvariant)
+  if (kind == TensorExp::Kind::kInvariant)
     return genInvariantValue(env, e);
-  if (kind == Kind::kLoopVar)
+  if (kind == TensorExp::Kind::kLoopVar)
     return env.getLoopVar(exp.loop);
 
-  if (kind == Kind::kReduce)
+  if (kind == TensorExp::Kind::kReduce)
     env.startCustomReduc(e); // enter custom
 
   Value v0 = genExp(env, rewriter, exp.children.e0, ldx);
   Value v1 = genExp(env, rewriter, exp.children.e1, ldx);
   Value ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
-  if (ee && (kind == Kind::kUnary || kind == Kind::kBinary ||
-             kind == Kind::kBinaryBranch || kind == Kind::kReduce ||
-             kind == Kind::kSelect))
+  if (ee &&
+      (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
+       kind == TensorExp::Kind::kBinaryBranch ||
+       kind == TensorExp::Kind::kReduce || kind == TensorExp::Kind::kSelect))
     ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
 
-  if (kind == Kind::kReduce)
+  if (kind == TensorExp::Kind::kReduce)
     env.endCustomReduc(); // exit custom
 
-  if (kind == kSelect) {
+  if (kind == TensorExp::Kind::kSelect) {
     assert(!exp.val);
     env.exp(e).val = v0; // Preserve value for later use.
   }
@@ -1148,7 +1150,7 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                           LoopId ldx, bool atStart) {
   if (exp == kInvalidId)
     return;
-  if (env.exp(exp).kind == Kind::kTensor) {
+  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
     // Inspect tensor indices.
     bool isAtLoop = ldx == kInvalidId;
     linalg::GenericOp op = env.op();
@@ -1192,18 +1194,18 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
       // Start or end loop invariant hoisting of a tensor load.
       env.exp(exp).val = atStart ? genTensorLoad(env, builder, exp) : Value();
     }
-  } else if (env.exp(exp).kind != Kind::kInvariant &&
-             env.exp(exp).kind != Kind::kLoopVar) {
+  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
+             env.exp(exp).kind != TensorExp::Kind::kLoopVar) {
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
-    if (env.exp(exp).kind == Kind::kReduce)
+    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
       env.startCustomReduc(exp); // enter custom
     const ExprId e0 = env.exp(exp).children.e0;
     const ExprId e1 = env.exp(exp).children.e1;
     genInvariants(env, builder, e0, ldx, atStart);
     genInvariants(env, builder, e1, ldx, atStart);
-    if (env.exp(exp).kind == Kind::kReduce)
+    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
       env.endCustomReduc(); // exit custom
   }
 }

diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 40db5411132b4..4a8c3cbfbe584 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -25,70 +25,70 @@ enum class ExpArity {
   kBinary,
 };
 
-static ExpArity getExpArity(Kind k) {
+static ExpArity getExpArity(TensorExp::Kind k) {
   switch (k) {
   // Leaf.
-  case kTensor:
-  case kInvariant:
-  case kLoopVar:
+  case TensorExp::Kind::kTensor:
+  case TensorExp::Kind::kInvariant:
+  case TensorExp::Kind::kLoopVar:
     return ExpArity::kNullary;
-  case kAbsF:
-  case kAbsC:
-  case kAbsI:
-  case kCeilF:
-  case kFloorF:
-  case kSqrtF:
-  case kSqrtC:
-  case kExpm1F:
-  case kExpm1C:
-  case kLog1pF:
-  case kLog1pC:
-  case kSinF:
-  case kSinC:
-  case kTanhF:
-  case kTanhC:
-  case kTruncF:
-  case kExtF:
-  case kCastFS:
-  case kCastFU:
-  case kCastSF:
-  case kCastUF:
-  case kCastS:
-  case kCastU:
-  case kCastIdx:
-  case kTruncI:
-  case kCIm:
-  case kCRe:
-  case kBitCast:
-  case kBinaryBranch:
-  case kUnary:
-  case kSelect:
-  case kNegF:
-  case kNegC:
-  case kNegI:
+  case TensorExp::Kind::kAbsF:
+  case TensorExp::Kind::kAbsC:
+  case TensorExp::Kind::kAbsI:
+  case TensorExp::Kind::kCeilF:
+  case TensorExp::Kind::kFloorF:
+  case TensorExp::Kind::kSqrtF:
+  case TensorExp::Kind::kSqrtC:
+  case TensorExp::Kind::kExpm1F:
+  case TensorExp::Kind::kExpm1C:
+  case TensorExp::Kind::kLog1pF:
+  case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kSinF:
+  case TensorExp::Kind::kSinC:
+  case TensorExp::Kind::kTanhF:
+  case TensorExp::Kind::kTanhC:
+  case TensorExp::Kind::kTruncF:
+  case TensorExp::Kind::kExtF:
+  case TensorExp::Kind::kCastFS:
+  case TensorExp::Kind::kCastFU:
+  case TensorExp::Kind::kCastSF:
+  case TensorExp::Kind::kCastUF:
+  case TensorExp::Kind::kCastS:
+  case TensorExp::Kind::kCastU:
+  case TensorExp::Kind::kCastIdx:
+  case TensorExp::Kind::kTruncI:
+  case TensorExp::Kind::kCIm:
+  case TensorExp::Kind::kCRe:
+  case TensorExp::Kind::kBitCast:
+  case TensorExp::Kind::kBinaryBranch:
+  case TensorExp::Kind::kUnary:
+  case TensorExp::Kind::kSelect:
+  case TensorExp::Kind::kNegF:
+  case TensorExp::Kind::kNegC:
+  case TensorExp::Kind::kNegI:
     return ExpArity::kUnary;
   // Binary operations.
-  case kDivF:
-  case kDivC:
-  case kDivS:
-  case kDivU:
-  case kShrS:
-  case kShrU:
-  case kShlI:
-  case kMulF:
-  case kMulC:
-  case kMulI:
-  case kAndI:
-  case kAddF:
-  case kAddC:
-  case kAddI:
-  case kOrI:
-  case kXorI:
-  case kBinary:
-  case kReduce:
-  case kSubF:
-  case kSubC:
-  case kSubI:
+  case TensorExp::Kind::kDivF:
+  case TensorExp::Kind::kDivC:
+  case TensorExp::Kind::kDivS:
+  case TensorExp::Kind::kDivU:
+  case TensorExp::Kind::kShrS:
+  case TensorExp::Kind::kShrU:
+  case TensorExp::Kind::kShlI:
+  case TensorExp::Kind::kMulF:
+  case TensorExp::Kind::kMulC:
+  case TensorExp::Kind::kMulI:
+  case TensorExp::Kind::kAndI:
+  case TensorExp::Kind::kAddF:
+  case TensorExp::Kind::kAddC:
+  case TensorExp::Kind::kAddI:
+  case TensorExp::Kind::kOrI:
+  case TensorExp::Kind::kXorI:
+  case TensorExp::Kind::kBinary:
+  case TensorExp::Kind::kReduce:
+  case TensorExp::Kind::kSubF:
+  case TensorExp::Kind::kSubC:
+  case TensorExp::Kind::kSubI:
     return ExpArity::kBinary;
   }
   llvm_unreachable("unexpected kind");
@@ -102,64 +102,64 @@ TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o)
     : kind(k), val(v), op(o) {
   switch (kind) {
   // Leaf.
-  case kTensor:
+  case TensorExp::Kind::kTensor:
     assert(x != kInvalidId && y == kInvalidId && !v && !o);
     tensor = x;
     break;
-  case kInvariant:
+  case TensorExp::Kind::kInvariant:
     assert(x == kInvalidId && y == kInvalidId && v && !o);
     break;
-  case kLoopVar:
+  case TensorExp::Kind::kLoopVar:
     assert(x != kInvalidId && y == kInvalidId && !v && !o);
     loop = x;
     break;
   // Unary operations.
-  case kAbsF:
-  case kAbsC:
-  case kAbsI:
-  case kCeilF:
-  case kFloorF:
-  case kSqrtF:
-  case kSqrtC:
-  case kExpm1F:
-  case kExpm1C:
-  case kLog1pF:
-  case kLog1pC:
-  case kSinF:
-  case kSinC:
-  case kTanhF:
-  case kTanhC:
-  case kNegF:
-  case kNegC:
-  case kNegI:
-  case kCIm:
-  case kCRe:
+  case TensorExp::Kind::kAbsF:
+  case TensorExp::Kind::kAbsC:
+  case TensorExp::Kind::kAbsI:
+  case TensorExp::Kind::kCeilF:
+  case TensorExp::Kind::kFloorF:
+  case TensorExp::Kind::kSqrtF:
+  case TensorExp::Kind::kSqrtC:
+  case TensorExp::Kind::kExpm1F:
+  case TensorExp::Kind::kExpm1C:
+  case TensorExp::Kind::kLog1pF:
+  case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kSinF:
+  case TensorExp::Kind::kSinC:
+  case TensorExp::Kind::kTanhF:
+  case TensorExp::Kind::kTanhC:
+  case TensorExp::Kind::kNegF:
+  case TensorExp::Kind::kNegC:
+  case TensorExp::Kind::kNegI:
+  case TensorExp::Kind::kCIm:
+  case TensorExp::Kind::kCRe:
     assert(x != kInvalidId && y == kInvalidId && !v && !o);
     children.e0 = x;
     children.e1 = y;
     break;
-  case kTruncF:
-  case kExtF:
-  case kCastFS:
-  case kCastFU:
-  case kCastSF:
-  case kCastUF:
-  case kCastS:
-  case kCastU:
-  case kCastIdx:
-  case kTruncI:
-  case kBitCast:
+  case TensorExp::Kind::kTruncF:
+  case TensorExp::Kind::kExtF:
+  case TensorExp::Kind::kCastFS:
+  case TensorExp::Kind::kCastFU:
+  case TensorExp::Kind::kCastSF:
+  case TensorExp::Kind::kCastUF:
+  case TensorExp::Kind::kCastS:
+  case TensorExp::Kind::kCastU:
+  case TensorExp::Kind::kCastIdx:
+  case TensorExp::Kind::kTruncI:
+  case TensorExp::Kind::kBitCast:
     assert(x != kInvalidId && y == kInvalidId && v && !o);
     children.e0 = x;
     children.e1 = y;
     break;
-  case kBinaryBranch:
-  case kSelect:
+  case TensorExp::Kind::kBinaryBranch:
+  case TensorExp::Kind::kSelect:
     assert(x != kInvalidId && y == kInvalidId && !v && o);
     children.e0 = x;
     children.e1 = y;
     break;
-  case kUnary:
+  case TensorExp::Kind::kUnary:
     // No assertion on y can be made, as the branching paths involve both
     // a unary (`mapSet`) and binary (`disjSet`) pathway.
     assert(x != kInvalidId && !v && o);
@@ -167,31 +167,31 @@ TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o)
     children.e1 = y;
     break;
   // Binary operations.
-  case kMulF:
-  case kMulC:
-  case kMulI:
-  case kDivF:
-  case kDivC:
-  case kDivS:
-  case kDivU:
-  case kAddF:
-  case kAddC:
-  case kAddI:
-  case kSubF:
-  case kSubC:
-  case kSubI:
-  case kAndI:
-  case kOrI:
-  case kXorI:
-  case kShrS:
-  case kShrU:
-  case kShlI:
+  case TensorExp::Kind::kMulF:
+  case TensorExp::Kind::kMulC:
+  case TensorExp::Kind::kMulI:
+  case TensorExp::Kind::kDivF:
+  case TensorExp::Kind::kDivC:
+  case TensorExp::Kind::kDivS:
+  case TensorExp::Kind::kDivU:
+  case TensorExp::Kind::kAddF:
+  case TensorExp::Kind::kAddC:
+  case TensorExp::Kind::kAddI:
+  case TensorExp::Kind::kSubF:
+  case TensorExp::Kind::kSubC:
+  case TensorExp::Kind::kSubI:
+  case TensorExp::Kind::kAndI:
+  case TensorExp::Kind::kOrI:
+  case TensorExp::Kind::kXorI:
+  case TensorExp::Kind::kShrS:
+  case TensorExp::Kind::kShrU:
+  case TensorExp::Kind::kShlI:
     assert(x != kInvalidId && y != kInvalidId && !v && !o);
     children.e0 = x;
     children.e1 = y;
     break;
-  case kBinary:
-  case kReduce:
+  case TensorExp::Kind::kBinary:
+  case TensorExp::Kind::kReduce:
     assert(x != kInvalidId && y != kInvalidId && !v && o);
     children.e0 = x;
     children.e1 = y;
@@ -231,9 +231,11 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
 // Lattice methods.
 //===----------------------------------------------------------------------===//
 
-ExprId Merger::addExp(Kind k, unsigned x, ExprId y, Value v, Operation *op) {
+ExprId Merger::addExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
+                      Operation *op) {
   const ExprId e = tensorExps.size();
-  assert((k != kTensor || x < numTensors) && (k != kLoopVar || x < numLoops));
+  assert((k != TensorExp::Kind::kTensor || x < numTensors) &&
+         (k != TensorExp::Kind::kLoopVar || x < numLoops));
   tensorExps.emplace_back(k, x, y, v, op);
   return e;
 }
@@ -251,7 +253,7 @@ LatSetId Merger::addSet() {
   return s;
 }
 
-LatPointId Merger::conjLat(Kind kind, LatPointId p0, LatPointId p1,
+LatPointId Merger::conjLat(TensorExp::Kind kind, LatPointId p0, LatPointId p1,
                            Operation *op) {
   const LatPointId p = latPoints.size();
   BitVector bits(latPoints[p0].bits);
@@ -262,7 +264,8 @@ LatPointId Merger::conjLat(Kind kind, LatPointId p0, LatPointId p1,
   return p;
 }
 
-LatSetId Merger::conjSet(Kind kind, LatSetId s0, LatSetId s1, Operation *op) {
+LatSetId Merger::conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
+                         Operation *op) {
   const LatSetId s = addSet();
   for (const LatPointId p0 : latSets[s0])
     for (const LatPointId p1 : latSets[s1])
@@ -270,28 +273,31 @@ LatSetId Merger::conjSet(Kind kind, LatSetId s0, LatSetId s1, Operation *op) {
   return s;
 }
 
-LatSetId Merger::disjSet(Kind kind, LatSetId s0, LatSetId s1, Operation *op) {
+LatSetId Merger::disjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
+                         Operation *op) {
   const LatSetId s = conjSet(kind, s0, s1, op);
   // Followed by all in s0.
   for (const LatPointId p : latSets[s0])
     latSets[s].push_back(p);
   // Map binary 0-y to unary -y.
   // TODO: move this if-else logic into buildLattices
-  if (kind == kSubF)
-    s1 = mapSet(kNegF, s1);
-  else if (kind == kSubC)
-    s1 = mapSet(kNegC, s1);
-  else if (kind == kSubI)
-    s1 = mapSet(kNegI, s1);
+  if (kind == TensorExp::Kind::kSubF)
+    s1 = mapSet(TensorExp::Kind::kNegF, s1);
+  else if (kind == TensorExp::Kind::kSubC)
+    s1 = mapSet(TensorExp::Kind::kNegC, s1);
+  else if (kind == TensorExp::Kind::kSubI)
+    s1 = mapSet(TensorExp::Kind::kNegI, s1);
   // Followed by all in s1.
   for (const LatPointId p : latSets[s1])
     latSets[s].push_back(p);
   return s;
 }
 
-LatSetId Merger::combiSet(Kind kind, LatSetId s0, LatSetId s1, Operation *orig,
-                          bool includeLeft, Kind ltrans, Operation *opleft,
-                          bool includeRight, Kind rtrans, Operation *opright) {
+LatSetId Merger::combiSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1,
+                          Operation *orig, bool includeLeft,
+                          TensorExp::Kind ltrans, Operation *opleft,
+                          bool includeRight, TensorExp::Kind rtrans,
+                          Operation *opright) {
   const LatSetId s = conjSet(kind, s0, s1, orig);
   // Left Region.
   if (includeLeft) {
@@ -310,8 +316,9 @@ LatSetId Merger::combiSet(Kind kind, LatSetId s0, LatSetId s1, Operation *orig,
   return s;
 }
 
-LatSetId Merger::mapSet(Kind kind, LatSetId s0, Value v, Operation *op) {
-  assert(kAbsF <= kind && kind <= kSelect);
+LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
+                        Operation *op) {
+  assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect);
   const LatSetId s = addSet();
   for (const LatPointId p : latSets[s0]) {
     const ExprId e = addExp(kind, latPoints[p].exp, v, op);
@@ -414,7 +421,7 @@ bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
 }
 
 bool Merger::expContainsTensor(ExprId e, TensorId t) const {
-  if (tensorExps[e].kind == kTensor)
+  if (tensorExps[e].kind == TensorExp::Kind::kTensor)
     return tensorExps[e].tensor == t;
 
   switch (getExpArity(tensorExps[e].kind)) {
@@ -439,13 +446,13 @@ bool Merger::expContainsTensor(ExprId e, TensorId t) const {
 
 bool Merger::hasNegateOnOut(ExprId e) const {
   switch (tensorExps[e].kind) {
-  case kNegF:
-  case kNegC:
-  case kNegI:
+  case TensorExp::Kind::kNegF:
+  case TensorExp::Kind::kNegC:
+  case TensorExp::Kind::kNegI:
     return expContainsTensor(tensorExps[e].children.e0, outTensor);
-  case kSubF:
-  case kSubC:
-  case kSubI:
+  case TensorExp::Kind::kSubF:
+  case TensorExp::Kind::kSubC:
+  case TensorExp::Kind::kSubI:
     return expContainsTensor(tensorExps[e].children.e1, outTensor) ||
            hasNegateOnOut(tensorExps[e].children.e0);
   default: {
@@ -467,82 +474,82 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
   assert(t < numTensors && e < tensorExps.size());
   switch (tensorExps[e].kind) {
   // Leaf.
-  case kTensor:
+  case TensorExp::Kind::kTensor:
     return tensorExps[e].tensor == t;
-  case kInvariant:
-  case kLoopVar:
+  case TensorExp::Kind::kInvariant:
+  case TensorExp::Kind::kLoopVar:
     return false;
   // Unary operations.
-  case kAbsF:
-  case kAbsC:
-  case kAbsI:
-  case kCeilF:
-  case kFloorF:
-  case kSqrtF:
-  case kSqrtC:
-  case kExpm1F:
-  case kExpm1C:
-  case kLog1pF:
-  case kLog1pC:
-  case kSinF:
-  case kSinC:
-  case kTanhF:
-  case kTanhC:
-  case kNegF:
-  case kNegC:
-  case kNegI:
-  case kTruncF:
-  case kExtF:
-  case kCastFS:
-  case kCastFU:
-  case kCastSF:
-  case kCastUF:
-  case kCastS:
-  case kCastU:
-  case kCastIdx:
-  case kTruncI:
-  case kCIm:
-  case kCRe:
-  case kBitCast:
+  case TensorExp::Kind::kAbsF:
+  case TensorExp::Kind::kAbsC:
+  case TensorExp::Kind::kAbsI:
+  case TensorExp::Kind::kCeilF:
+  case TensorExp::Kind::kFloorF:
+  case TensorExp::Kind::kSqrtF:
+  case TensorExp::Kind::kSqrtC:
+  case TensorExp::Kind::kExpm1F:
+  case TensorExp::Kind::kExpm1C:
+  case TensorExp::Kind::kLog1pF:
+  case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kSinF:
+  case TensorExp::Kind::kSinC:
+  case TensorExp::Kind::kTanhF:
+  case TensorExp::Kind::kTanhC:
+  case TensorExp::Kind::kNegF:
+  case TensorExp::Kind::kNegC:
+  case TensorExp::Kind::kNegI:
+  case TensorExp::Kind::kTruncF:
+  case TensorExp::Kind::kExtF:
+  case TensorExp::Kind::kCastFS:
+  case TensorExp::Kind::kCastFU:
+  case TensorExp::Kind::kCastSF:
+  case TensorExp::Kind::kCastUF:
+  case TensorExp::Kind::kCastS:
+  case TensorExp::Kind::kCastU:
+  case TensorExp::Kind::kCastIdx:
+  case TensorExp::Kind::kTruncI:
+  case TensorExp::Kind::kCIm:
+  case TensorExp::Kind::kCRe:
+  case TensorExp::Kind::kBitCast:
     return isSingleCondition(t, tensorExps[e].children.e0);
-  case kBinaryBranch:
-  case kUnary:
-  case kSelect:
+  case TensorExp::Kind::kBinaryBranch:
+  case TensorExp::Kind::kUnary:
+  case TensorExp::Kind::kSelect:
     return false;
   // Binary operations.
-  case kDivF: // note: x / c only
-  case kDivC:
-  case kDivS:
-  case kDivU:
+  case TensorExp::Kind::kDivF: // note: x / c only
+  case TensorExp::Kind::kDivC:
+  case TensorExp::Kind::kDivS:
+  case TensorExp::Kind::kDivU:
     assert(!maybeZero(tensorExps[e].children.e1));
     return isSingleCondition(t, tensorExps[e].children.e0);
-  case kShrS: // note: x >> inv only
-  case kShrU:
-  case kShlI:
+  case TensorExp::Kind::kShrS: // note: x >> inv only
+  case TensorExp::Kind::kShrU:
+  case TensorExp::Kind::kShlI:
     assert(isInvariant(tensorExps[e].children.e1));
     return isSingleCondition(t, tensorExps[e].children.e0);
-  case kMulF:
-  case kMulC:
-  case kMulI:
-  case kAndI:
+  case TensorExp::Kind::kMulF:
+  case TensorExp::Kind::kMulC:
+  case TensorExp::Kind::kMulI:
+  case TensorExp::Kind::kAndI:
     if (isSingleCondition(t, tensorExps[e].children.e0))
       return isSingleCondition(t, tensorExps[e].children.e1) ||
              isInvariant(tensorExps[e].children.e1);
     if (isSingleCondition(t, tensorExps[e].children.e1))
       return isInvariant(tensorExps[e].children.e0);
     return false;
-  case kAddF:
-  case kAddC:
-  case kAddI:
+  case TensorExp::Kind::kAddF:
+  case TensorExp::Kind::kAddC:
+  case TensorExp::Kind::kAddI:
     return isSingleCondition(t, tensorExps[e].children.e0) &&
            isSingleCondition(t, tensorExps[e].children.e1);
-  case kSubF:
-  case kSubC:
-  case kSubI:
-  case kOrI:
-  case kXorI:
-  case kBinary:
-  case kReduce:
+  case TensorExp::Kind::kSubF:
+  case TensorExp::Kind::kSubC:
+  case TensorExp::Kind::kSubI:
+  case TensorExp::Kind::kOrI:
+  case TensorExp::Kind::kXorI:
+  case TensorExp::Kind::kBinary:
+  case TensorExp::Kind::kReduce:
     return false;
   }
   llvm_unreachable("unexpected kind");
@@ -572,98 +579,98 @@ bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
 // Print methods (for debugging).
 //===----------------------------------------------------------------------===//
 
-static const char *kindToOpSymbol(Kind kind) {
+static const char *kindToOpSymbol(TensorExp::Kind kind) {
   switch (kind) {
   // Leaf.
-  case kTensor:
+  case TensorExp::Kind::kTensor:
     return "tensor";
-  case kInvariant:
+  case TensorExp::Kind::kInvariant:
     return "invariant";
-  case kLoopVar:
+  case TensorExp::Kind::kLoopVar:
     return "index";
   // Unary operations.
-  case kAbsF:
-  case kAbsC:
-  case kAbsI:
+  case TensorExp::Kind::kAbsF:
+  case TensorExp::Kind::kAbsC:
+  case TensorExp::Kind::kAbsI:
     return "abs";
-  case kCeilF:
+  case TensorExp::Kind::kCeilF:
     return "ceil";
-  case kFloorF:
+  case TensorExp::Kind::kFloorF:
     return "floor";
-  case kSqrtF:
-  case kSqrtC:
+  case TensorExp::Kind::kSqrtF:
+  case TensorExp::Kind::kSqrtC:
     return "sqrt";
-  case kExpm1F:
-  case kExpm1C:
+  case TensorExp::Kind::kExpm1F:
+  case TensorExp::Kind::kExpm1C:
     return "expm1";
-  case kLog1pF:
-  case kLog1pC:
+  case TensorExp::Kind::kLog1pF:
+  case TensorExp::Kind::kLog1pC:
     return "log1p";
-  case kSinF:
-  case kSinC:
+  case TensorExp::Kind::kSinF:
+  case TensorExp::Kind::kSinC:
     return "sin";
-  case kTanhF:
-  case kTanhC:
+  case TensorExp::Kind::kTanhF:
+  case TensorExp::Kind::kTanhC:
     return "tanh";
-  case kNegF:
-  case kNegC:
-  case kNegI:
+  case TensorExp::Kind::kNegF:
+  case TensorExp::Kind::kNegC:
+  case TensorExp::Kind::kNegI:
     return "-";
-  case kTruncF:
-  case kExtF:
-  case kCastFS:
-  case kCastFU:
-  case kCastSF:
-  case kCastUF:
-  case kCastS:
-  case kCastU:
-  case kCastIdx:
-  case kTruncI:
-  case kCIm:
+  case TensorExp::Kind::kTruncF:
+  case TensorExp::Kind::kExtF:
+  case TensorExp::Kind::kCastFS:
+  case TensorExp::Kind::kCastFU:
+  case TensorExp::Kind::kCastSF:
+  case TensorExp::Kind::kCastUF:
+  case TensorExp::Kind::kCastS:
+  case TensorExp::Kind::kCastU:
+  case TensorExp::Kind::kCastIdx:
+  case TensorExp::Kind::kTruncI:
+  case TensorExp::Kind::kCIm:
     return "complex.im";
-  case kCRe:
+  case TensorExp::Kind::kCRe:
     return "complex.re";
-  case kBitCast:
+  case TensorExp::Kind::kBitCast:
     return "cast";
-  case kBinaryBranch:
+  case TensorExp::Kind::kBinaryBranch:
     return "binary_branch";
-  case kUnary:
+  case TensorExp::Kind::kUnary:
     return "unary";
-  case kSelect:
+  case TensorExp::Kind::kSelect:
     return "select";
   // Binary operations.
-  case kMulF:
-  case kMulC:
-  case kMulI:
+  case TensorExp::Kind::kMulF:
+  case TensorExp::Kind::kMulC:
+  case TensorExp::Kind::kMulI:
     return "*";
-  case kDivF:
-  case kDivC:
-  case kDivS:
-  case kDivU:
+  case TensorExp::Kind::kDivF:
+  case TensorExp::Kind::kDivC:
+  case TensorExp::Kind::kDivS:
+  case TensorExp::Kind::kDivU:
     return "/";
-  case kAddF:
-  case kAddC:
-  case kAddI:
+  case TensorExp::Kind::kAddF:
+  case TensorExp::Kind::kAddC:
+  case TensorExp::Kind::kAddI:
     return "+";
-  case kSubF:
-  case kSubC:
-  case kSubI:
+  case TensorExp::Kind::kSubF:
+  case TensorExp::Kind::kSubC:
+  case TensorExp::Kind::kSubI:
     return "-";
-  case kAndI:
+  case TensorExp::Kind::kAndI:
     return "&";
-  case kOrI:
+  case TensorExp::Kind::kOrI:
     return "|";
-  case kXorI:
+  case TensorExp::Kind::kXorI:
     return "^";
-  case kShrS:
+  case TensorExp::Kind::kShrS:
     return "a>>";
-  case kShrU:
+  case TensorExp::Kind::kShrU:
     return ">>";
-  case kShlI:
+  case TensorExp::Kind::kShlI:
     return "<<";
-  case kBinary:
+  case TensorExp::Kind::kBinary:
     return "binary";
-  case kReduce:
+  case TensorExp::Kind::kReduce:
     return "reduce";
   }
   llvm_unreachable("unexpected kind for symbol");
@@ -672,79 +679,79 @@ static const char *kindToOpSymbol(Kind kind) {
 void Merger::dumpExp(ExprId e) const {
   switch (tensorExps[e].kind) {
   // Leaf.
-  case kTensor:
+  case TensorExp::Kind::kTensor:
     if (tensorExps[e].tensor == syntheticTensor)
       llvm::dbgs() << "synthetic_";
     else if (tensorExps[e].tensor == outTensor)
       llvm::dbgs() << "output_";
     llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
     break;
-  case kInvariant:
+  case TensorExp::Kind::kInvariant:
     llvm::dbgs() << "invariant";
     break;
-  case kLoopVar:
+  case TensorExp::Kind::kLoopVar:
     llvm::dbgs() << "loopvar_" << tensorExps[e].loop;
     break;
   // Unary operations.
-  case kAbsF:
-  case kAbsC:
-  case kAbsI:
-  case kCeilF:
-  case kFloorF:
-  case kSqrtF:
-  case kSqrtC:
-  case kExpm1F:
-  case kExpm1C:
-  case kLog1pF:
-  case kLog1pC:
-  case kSinF:
-  case kSinC:
-  case kTanhF:
-  case kTanhC:
-  case kNegF:
-  case kNegC:
-  case kNegI:
-  case kTruncF:
-  case kExtF:
-  case kCastFS:
-  case kCastFU:
-  case kCastSF:
-  case kCastUF:
-  case kCastS:
-  case kCastU:
-  case kCastIdx:
-  case kTruncI:
-  case kCIm:
-  case kCRe:
-  case kBitCast:
-  case kBinaryBranch:
-  case kUnary:
-  case kSelect:
+  case TensorExp::Kind::kAbsF:
+  case TensorExp::Kind::kAbsC:
+  case TensorExp::Kind::kAbsI:
+  case TensorExp::Kind::kCeilF:
+  case TensorExp::Kind::kFloorF:
+  case TensorExp::Kind::kSqrtF:
+  case TensorExp::Kind::kSqrtC:
+  case TensorExp::Kind::kExpm1F:
+  case TensorExp::Kind::kExpm1C:
+  case TensorExp::Kind::kLog1pF:
+  case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kSinF:
+  case TensorExp::Kind::kSinC:
+  case TensorExp::Kind::kTanhF:
+  case TensorExp::Kind::kTanhC:
+  case TensorExp::Kind::kNegF:
+  case TensorExp::Kind::kNegC:
+  case TensorExp::Kind::kNegI:
+  case TensorExp::Kind::kTruncF:
+  case TensorExp::Kind::kExtF:
+  case TensorExp::Kind::kCastFS:
+  case TensorExp::Kind::kCastFU:
+  case TensorExp::Kind::kCastSF:
+  case TensorExp::Kind::kCastUF:
+  case TensorExp::Kind::kCastS:
+  case TensorExp::Kind::kCastU:
+  case TensorExp::Kind::kCastIdx:
+  case TensorExp::Kind::kTruncI:
+  case TensorExp::Kind::kCIm:
+  case TensorExp::Kind::kCRe:
+  case TensorExp::Kind::kBitCast:
+  case TensorExp::Kind::kBinaryBranch:
+  case TensorExp::Kind::kUnary:
+  case TensorExp::Kind::kSelect:
     llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
     dumpExp(tensorExps[e].children.e0);
     break;
   // Binary operations.
-  case kMulF:
-  case kMulC:
-  case kMulI:
-  case kDivF:
-  case kDivC:
-  case kDivS:
-  case kDivU:
-  case kAddF:
-  case kAddC:
-  case kAddI:
-  case kSubF:
-  case kSubC:
-  case kSubI:
-  case kAndI:
-  case kOrI:
-  case kXorI:
-  case kShrS:
-  case kShrU:
-  case kShlI:
-  case kBinary:
-  case kReduce:
+  case TensorExp::Kind::kMulF:
+  case TensorExp::Kind::kMulC:
+  case TensorExp::Kind::kMulI:
+  case TensorExp::Kind::kDivF:
+  case TensorExp::Kind::kDivC:
+  case TensorExp::Kind::kDivS:
+  case TensorExp::Kind::kDivU:
+  case TensorExp::Kind::kAddF:
+  case TensorExp::Kind::kAddC:
+  case TensorExp::Kind::kAddI:
+  case TensorExp::Kind::kSubF:
+  case TensorExp::Kind::kSubC:
+  case TensorExp::Kind::kSubI:
+  case TensorExp::Kind::kAndI:
+  case TensorExp::Kind::kOrI:
+  case TensorExp::Kind::kXorI:
+  case TensorExp::Kind::kShrS:
+  case TensorExp::Kind::kShrU:
+  case TensorExp::Kind::kShlI:
+  case TensorExp::Kind::kBinary:
+  case TensorExp::Kind::kReduce:
     llvm::dbgs() << "(";
     dumpExp(tensorExps[e].children.e0);
     llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
@@ -793,12 +800,12 @@ void Merger::dumpBits(const BitVector &bits) const {
 //===----------------------------------------------------------------------===//
 
 LatSetId Merger::buildLattices(ExprId e, LoopId i) {
-  const Kind kind = tensorExps[e].kind;
+  const TensorExp::Kind kind = tensorExps[e].kind;
   switch (kind) {
   // Leaf.
-  case kTensor:
-  case kInvariant:
-  case kLoopVar: {
+  case TensorExp::Kind::kTensor:
+  case TensorExp::Kind::kInvariant:
+  case TensorExp::Kind::kLoopVar: {
     // Either the loop-var is really used in the tensor expression, or it is
     // set to the undefined loop-var in that level. An invariant expression,
     // a proper index value, and a truly dynamic sparse output tensor are set
@@ -806,7 +813,7 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
     // iteration space is not skipped as a result of their contents.
     const LatSetId s = addSet();
     TensorId t = syntheticTensor;
-    if (kind == kTensor) {
+    if (kind == TensorExp::Kind::kTensor) {
       t = tensorExps[e].tensor;
       if (hasSparseOut && t == outTensor)
         t = syntheticTensor;
@@ -815,37 +822,37 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
     return s;
   }
   // Unary operations.
-  case kAbsF:
-  case kAbsC:
-  case kAbsI:
-  case kCeilF:
-  case kFloorF:
-  case kSqrtF:
-  case kSqrtC:
-  case kExpm1F:
-  case kExpm1C:
-  case kLog1pF:
-  case kLog1pC:
-  case kSinF:
-  case kSinC:
-  case kTanhF:
-  case kTanhC:
-  case kNegF:
-  case kNegC:
-  case kNegI:
-  case kTruncF:
-  case kExtF:
-  case kCastFS:
-  case kCastFU:
-  case kCastSF:
-  case kCastUF:
-  case kCastS:
-  case kCastU:
-  case kCastIdx:
-  case kTruncI:
-  case kCIm:
-  case kCRe:
-  case kBitCast:
+  case TensorExp::Kind::kAbsF:
+  case TensorExp::Kind::kAbsC:
+  case TensorExp::Kind::kAbsI:
+  case TensorExp::Kind::kCeilF:
+  case TensorExp::Kind::kFloorF:
+  case TensorExp::Kind::kSqrtF:
+  case TensorExp::Kind::kSqrtC:
+  case TensorExp::Kind::kExpm1F:
+  case TensorExp::Kind::kExpm1C:
+  case TensorExp::Kind::kLog1pF:
+  case TensorExp::Kind::kLog1pC:
+  case TensorExp::Kind::kSinF:
+  case TensorExp::Kind::kSinC:
+  case TensorExp::Kind::kTanhF:
+  case TensorExp::Kind::kTanhC:
+  case TensorExp::Kind::kNegF:
+  case TensorExp::Kind::kNegC:
+  case TensorExp::Kind::kNegI:
+  case TensorExp::Kind::kTruncF:
+  case TensorExp::Kind::kExtF:
+  case TensorExp::Kind::kCastFS:
+  case TensorExp::Kind::kCastFU:
+  case TensorExp::Kind::kCastSF:
+  case TensorExp::Kind::kCastUF:
+  case TensorExp::Kind::kCastS:
+  case TensorExp::Kind::kCastU:
+  case TensorExp::Kind::kCastIdx:
+  case TensorExp::Kind::kTruncI:
+  case TensorExp::Kind::kCIm:
+  case TensorExp::Kind::kCRe:
+  case TensorExp::Kind::kBitCast:
     // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
     // lattice set of the operand through the operator into a new set.
     //
@@ -854,13 +861,13 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
     //    | 0 |-y |
     return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                   tensorExps[e].val);
-  case kBinaryBranch:
-  case kSelect:
+  case TensorExp::Kind::kBinaryBranch:
+  case TensorExp::Kind::kSelect:
     // The left or right half of a binary operation which has already
     // been split into separate operations for each region.
     return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                   tensorExps[e].op);
-  case kUnary:
+  case TensorExp::Kind::kUnary:
     // A custom unary operation.
     //
     //  op y|    !y    |     y      |
@@ -879,14 +886,14 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
       Block &absentBlock = absentRegion.front();
       YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
       Value absentVal = absentYield.getResult();
-      const ExprId rhs = addExp(kInvariant, absentVal);
+      const ExprId rhs = addExp(TensorExp::Kind::kInvariant, absentVal);
       return disjSet(kind, child0, buildLattices(rhs, i), unop);
     }
   // Binary operations.
-  case kMulF:
-  case kMulC:
-  case kMulI:
-  case kAndI:
+  case TensorExp::Kind::kMulF:
+  case TensorExp::Kind::kMulC:
+  case TensorExp::Kind::kMulI:
+  case TensorExp::Kind::kAndI:
     // A multiplicative operation only needs to be performed
     // for the conjunction of sparse iteration spaces.
     //
@@ -898,10 +905,10 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
     // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
     return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
-  case kDivF:
-  case kDivC:
-  case kDivS:
-  case kDivU:
+  case TensorExp::Kind::kDivF:
+  case TensorExp::Kind::kDivC:
+  case TensorExp::Kind::kDivS:
+  case TensorExp::Kind::kDivU:
     // A division is tricky, since 0/0, 0/c, c/0 all have
     // specific outcomes for floating-point and integers.
     // Thus, we need to traverse the full iteration space.
@@ -918,14 +925,14 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
     assert(!maybeZero(tensorExps[e].children.e1));
     return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
-  case kAddF:
-  case kAddC:
-  case kAddI:
-  case kSubF:
-  case kSubC:
-  case kSubI:
-  case kOrI:
-  case kXorI:
+  case TensorExp::Kind::kAddF:
+  case TensorExp::Kind::kAddC:
+  case TensorExp::Kind::kAddI:
+  case TensorExp::Kind::kSubF:
+  case TensorExp::Kind::kSubC:
+  case TensorExp::Kind::kSubI:
+  case TensorExp::Kind::kOrI:
+  case TensorExp::Kind::kXorI:
     // An additive operation needs to be performed
     // for the disjunction of sparse iteration spaces.
     //
@@ -935,16 +942,16 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
     //   x | x |x+y|     x | x |x-y|
     return disjSet(kind, buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
-  case kShrS:
-  case kShrU:
-  case kShlI:
+  case TensorExp::Kind::kShrS:
+  case TensorExp::Kind::kShrU:
+  case TensorExp::Kind::kShlI:
     // A shift operation by an invariant amount (viz. tensor expressions
     // can only occur at the left-hand-side of the operator) can be handled
     // with the conjuction rule.
     assert(isInvariant(tensorExps[e].children.e1));
     return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
-  case kBinary:
+  case TensorExp::Kind::kBinary:
     // A custom binary operation.
     //
     //  x op y|   !y    |       y      |
@@ -971,11 +978,11 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
       }
       bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
       bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
-      return combiSet(kBinary, child0, child1, binop, includeLeft,
-                      kBinaryBranch, leftYield, includeRight, kBinaryBranch,
-                      rightYield);
+      return combiSet(TensorExp::Kind::kBinary, child0, child1, binop,
+                      includeLeft, TensorExp::Kind::kBinaryBranch, leftYield,
+                      includeRight, TensorExp::Kind::kBinaryBranch, rightYield);
     }
-  case kReduce:
+  case TensorExp::Kind::kReduce:
     // A custom reduce operation.
     return conjSet(kind, buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i),
@@ -993,7 +1000,7 @@ std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
 
 /// Only returns false if we are certain this is a nonzero.
 bool Merger::maybeZero(ExprId e) const {
-  if (tensorExps[e].kind == kInvariant) {
+  if (tensorExps[e].kind == TensorExp::Kind::kInvariant) {
     if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
       ArrayAttr arrayAttr = c.getValue();
       return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
@@ -1008,7 +1015,7 @@ bool Merger::maybeZero(ExprId e) const {
 }
 
 bool Merger::isInvariant(ExprId e) const {
-  return tensorExps[e].kind == kInvariant;
+  return tensorExps[e].kind == TensorExp::Kind::kInvariant;
 }
 
 Type Merger::inferType(ExprId e, Value src) const {
@@ -1060,21 +1067,21 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
     if (arg.getOwner()->getParentOp() == op) {
       OpOperand &t = op->getOpOperand(argN);
       if (!op.isScalar(&t))
-        return addExp(kTensor, argN);
+        return addExp(TensorExp::Kind::kTensor, argN);
       v = t.get(); // get scalar value
     }
     // Any other argument (marked as scalar argument for the generic op
     // or belonging to an enveloping op) is considered invariant.
-    return addExp(kInvariant, v);
+    return addExp(TensorExp::Kind::kInvariant, v);
   }
   // Something defined outside is invariant.
   Operation *def = v.getDefiningOp();
   if (def->getBlock() != &op.getRegion().front())
-    return addExp(kInvariant, v);
+    return addExp(TensorExp::Kind::kInvariant, v);
   // Construct index operations.
   if (def->getNumOperands() == 0) {
     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
-      return addExp(kLoopVar, indexOp.getDim());
+      return addExp(TensorExp::Kind::kLoopVar, indexOp.getDim());
   }
   // Construct unary operations if subexpression can be built.
   if (def->getNumOperands() == 1) {
@@ -1082,73 +1089,73 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
     if (x.has_value()) {
       const ExprId e = *x;
       if (isa<math::AbsFOp>(def))
-        return addExp(kAbsF, e);
+        return addExp(TensorExp::Kind::kAbsF, e);
       if (isa<complex::AbsOp>(def))
-        return addExp(kAbsC, e);
+        return addExp(TensorExp::Kind::kAbsC, e);
       if (isa<math::AbsIOp>(def))
-        return addExp(kAbsI, e);
+        return addExp(TensorExp::Kind::kAbsI, e);
       if (isa<math::CeilOp>(def))
-        return addExp(kCeilF, e);
+        return addExp(TensorExp::Kind::kCeilF, e);
       if (isa<math::FloorOp>(def))
-        return addExp(kFloorF, e);
+        return addExp(TensorExp::Kind::kFloorF, e);
       if (isa<math::SqrtOp>(def))
-        return addExp(kSqrtF, e);
+        return addExp(TensorExp::Kind::kSqrtF, e);
       if (isa<complex::SqrtOp>(def))
-        return addExp(kSqrtC, e);
+        return addExp(TensorExp::Kind::kSqrtC, e);
       if (isa<math::ExpM1Op>(def))
-        return addExp(kExpm1F, e);
+        return addExp(TensorExp::Kind::kExpm1F, e);
       if (isa<complex::Expm1Op>(def))
-        return addExp(kExpm1C, e);
+        return addExp(TensorExp::Kind::kExpm1C, e);
       if (isa<math::Log1pOp>(def))
-        return addExp(kLog1pF, e);
+        return addExp(TensorExp::Kind::kLog1pF, e);
       if (isa<complex::Log1pOp>(def))
-        return addExp(kLog1pC, e);
+        return addExp(TensorExp::Kind::kLog1pC, e);
       if (isa<math::SinOp>(def))
-        return addExp(kSinF, e);
+        return addExp(TensorExp::Kind::kSinF, e);
       if (isa<complex::SinOp>(def))
-        return addExp(kSinC, e);
+        return addExp(TensorExp::Kind::kSinC, e);
       if (isa<math::TanhOp>(def))
-        return addExp(kTanhF, e);
+        return addExp(TensorExp::Kind::kTanhF, e);
       if (isa<complex::TanhOp>(def))
-        return addExp(kTanhC, e);
+        return addExp(TensorExp::Kind::kTanhC, e);
       if (isa<arith::NegFOp>(def))
-        return addExp(kNegF, e); // no negi in std
+        return addExp(TensorExp::Kind::kNegF, e); // no negi in std
       if (isa<complex::NegOp>(def))
-        return addExp(kNegC, e);
+        return addExp(TensorExp::Kind::kNegC, e);
       if (isa<arith::TruncFOp>(def))
-        return addExp(kTruncF, e, v);
+        return addExp(TensorExp::Kind::kTruncF, e, v);
       if (isa<arith::ExtFOp>(def))
-        return addExp(kExtF, e, v);
+        return addExp(TensorExp::Kind::kExtF, e, v);
       if (isa<arith::FPToSIOp>(def))
-        return addExp(kCastFS, e, v);
+        return addExp(TensorExp::Kind::kCastFS, e, v);
       if (isa<arith::FPToUIOp>(def))
-        return addExp(kCastFU, e, v);
+        return addExp(TensorExp::Kind::kCastFU, e, v);
       if (isa<arith::SIToFPOp>(def))
-        return addExp(kCastSF, e, v);
+        return addExp(TensorExp::Kind::kCastSF, e, v);
       if (isa<arith::UIToFPOp>(def))
-        return addExp(kCastUF, e, v);
+        return addExp(TensorExp::Kind::kCastUF, e, v);
       if (isa<arith::ExtSIOp>(def))
-        return addExp(kCastS, e, v);
+        return addExp(TensorExp::Kind::kCastS, e, v);
       if (isa<arith::ExtUIOp>(def))
-        return addExp(kCastU, e, v);
+        return addExp(TensorExp::Kind::kCastU, e, v);
       if (isa<arith::IndexCastOp>(def))
-        return addExp(kCastIdx, e, v);
+        return addExp(TensorExp::Kind::kCastIdx, e, v);
       if (isa<arith::TruncIOp>(def))
-        return addExp(kTruncI, e, v);
+        return addExp(TensorExp::Kind::kTruncI, e, v);
       if (isa<complex::ImOp>(def))
-        return addExp(kCIm, e);
+        return addExp(TensorExp::Kind::kCIm, e);
       if (isa<complex::ReOp>(def))
-        return addExp(kCRe, e);
+        return addExp(TensorExp::Kind::kCRe, e);
       if (isa<arith::BitcastOp>(def))
-        return addExp(kBitCast, e, v);
+        return addExp(TensorExp::Kind::kBitCast, e, v);
       if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
         if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
             isAdmissibleBranch(unop, unop.getAbsentRegion()))
-          return addExp(kUnary, e, Value(), def);
+          return addExp(TensorExp::Kind::kUnary, e, Value(), def);
       }
       if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
         if (isAdmissibleBranch(selop, selop.getRegion()))
-          return addExp(kSelect, e, Value(), def);
+          return addExp(TensorExp::Kind::kSelect, e, Value(), def);
       }
     }
   }
@@ -1162,50 +1169,50 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       const ExprId e0 = *x;
       const ExprId e1 = *y;
       if (isa<arith::MulFOp>(def))
-        return addExp(kMulF, e0, e1);
+        return addExp(TensorExp::Kind::kMulF, e0, e1);
       if (isa<complex::MulOp>(def))
-        return addExp(kMulC, e0, e1);
+        return addExp(TensorExp::Kind::kMulC, e0, e1);
       if (isa<arith::MulIOp>(def))
-        return addExp(kMulI, e0, e1);
+        return addExp(TensorExp::Kind::kMulI, e0, e1);
       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
-        return addExp(kDivF, e0, e1);
+        return addExp(TensorExp::Kind::kDivF, e0, e1);
       if (isa<complex::DivOp>(def) && !maybeZero(e1))
-        return addExp(kDivC, e0, e1);
+        return addExp(TensorExp::Kind::kDivC, e0, e1);
       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
-        return addExp(kDivS, e0, e1);
+        return addExp(TensorExp::Kind::kDivS, e0, e1);
       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
-        return addExp(kDivU, e0, e1);
+        return addExp(TensorExp::Kind::kDivU, e0, e1);
       if (isa<arith::AddFOp>(def))
-        return addExp(kAddF, e0, e1);
+        return addExp(TensorExp::Kind::kAddF, e0, e1);
       if (isa<complex::AddOp>(def))
-        return addExp(kAddC, e0, e1);
+        return addExp(TensorExp::Kind::kAddC, e0, e1);
       if (isa<arith::AddIOp>(def))
-        return addExp(kAddI, e0, e1);
+        return addExp(TensorExp::Kind::kAddI, e0, e1);
       if (isa<arith::SubFOp>(def))
-        return addExp(kSubF, e0, e1);
+        return addExp(TensorExp::Kind::kSubF, e0, e1);
       if (isa<complex::SubOp>(def))
-        return addExp(kSubC, e0, e1);
+        return addExp(TensorExp::Kind::kSubC, e0, e1);
       if (isa<arith::SubIOp>(def))
-        return addExp(kSubI, e0, e1);
+        return addExp(TensorExp::Kind::kSubI, e0, e1);
       if (isa<arith::AndIOp>(def))
-        return addExp(kAndI, e0, e1);
+        return addExp(TensorExp::Kind::kAndI, e0, e1);
       if (isa<arith::OrIOp>(def))
-        return addExp(kOrI, e0, e1);
+        return addExp(TensorExp::Kind::kOrI, e0, e1);
       if (isa<arith::XOrIOp>(def))
-        return addExp(kXorI, e0, e1);
+        return addExp(TensorExp::Kind::kXorI, e0, e1);
       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
-        return addExp(kShrS, e0, e1);
+        return addExp(TensorExp::Kind::kShrS, e0, e1);
       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
-        return addExp(kShrU, e0, e1);
+        return addExp(TensorExp::Kind::kShrU, e0, e1);
       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
-        return addExp(kShlI, e0, e1);
+        return addExp(TensorExp::Kind::kShlI, e0, e1);
       if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
             (binop.getLeftIdentity() ||
              isAdmissibleBranch(binop, binop.getLeftRegion())) &&
             (binop.getRightIdentity() ||
              isAdmissibleBranch(binop, binop.getRightRegion())))
-          return addExp(kBinary, e0, e1, Value(), def);
+          return addExp(TensorExp::Kind::kBinary, e0, e1, Value(), def);
       }
     }
   }
@@ -1219,7 +1226,7 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       const ExprId e1 = *y;
       if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
         if (isAdmissibleBranch(redop, redop.getRegion()))
-          return addExp(kReduce, e0, e1, Value(), def);
+          return addExp(TensorExp::Kind::kReduce, e0, e1, Value(), def);
       }
     }
   }
@@ -1276,136 +1283,136 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                        Value v1) const {
   switch (tensorExps[e].kind) {
   // Leaf.
-  case kTensor:
-  case kInvariant:
-  case kLoopVar:
+  case TensorExp::Kind::kTensor:
+  case TensorExp::Kind::kInvariant:
+  case TensorExp::Kind::kLoopVar:
     llvm_unreachable("unexpected non-op");
   // Unary operations.
-  case kAbsF:
+  case TensorExp::Kind::kAbsF:
     return rewriter.create<math::AbsFOp>(loc, v0);
-  case kAbsC: {
+  case TensorExp::Kind::kAbsC: {
     auto type = v0.getType().cast<ComplexType>();
     auto eltType = type.getElementType().cast<FloatType>();
     return rewriter.create<complex::AbsOp>(loc, eltType, v0);
   }
-  case kAbsI:
+  case TensorExp::Kind::kAbsI:
     return rewriter.create<math::AbsIOp>(loc, v0);
-  case kCeilF:
+  case TensorExp::Kind::kCeilF:
     return rewriter.create<math::CeilOp>(loc, v0);
-  case kFloorF:
+  case TensorExp::Kind::kFloorF:
     return rewriter.create<math::FloorOp>(loc, v0);
-  case kSqrtF:
+  case TensorExp::Kind::kSqrtF:
     return rewriter.create<math::SqrtOp>(loc, v0);
-  case kSqrtC:
+  case TensorExp::Kind::kSqrtC:
     return rewriter.create<complex::SqrtOp>(loc, v0);
-  case kExpm1F:
+  case TensorExp::Kind::kExpm1F:
     return rewriter.create<math::ExpM1Op>(loc, v0);
-  case kExpm1C:
+  case TensorExp::Kind::kExpm1C:
     return rewriter.create<complex::Expm1Op>(loc, v0);
-  case kLog1pF:
+  case TensorExp::Kind::kLog1pF:
     return rewriter.create<math::Log1pOp>(loc, v0);
-  case kLog1pC:
+  case TensorExp::Kind::kLog1pC:
     return rewriter.create<complex::Log1pOp>(loc, v0);
-  case kSinF:
+  case TensorExp::Kind::kSinF:
     return rewriter.create<math::SinOp>(loc, v0);
-  case kSinC:
+  case TensorExp::Kind::kSinC:
     return rewriter.create<complex::SinOp>(loc, v0);
-  case kTanhF:
+  case TensorExp::Kind::kTanhF:
     return rewriter.create<math::TanhOp>(loc, v0);
-  case kTanhC:
+  case TensorExp::Kind::kTanhC:
     return rewriter.create<complex::TanhOp>(loc, v0);
-  case kNegF:
+  case TensorExp::Kind::kNegF:
     return rewriter.create<arith::NegFOp>(loc, v0);
-  case kNegC:
+  case TensorExp::Kind::kNegC:
     return rewriter.create<complex::NegOp>(loc, v0);
-  case kNegI: // no negi in std
+  case TensorExp::Kind::kNegI: // no negi in std
     return rewriter.create<arith::SubIOp>(
         loc,
         rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                            rewriter.getZeroAttr(v0.getType())),
         v0);
-  case kTruncF:
+  case TensorExp::Kind::kTruncF:
     return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
-  case kExtF:
+  case TensorExp::Kind::kExtF:
     return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
-  case kCastFS:
+  case TensorExp::Kind::kCastFS:
     return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
-  case kCastFU:
+  case TensorExp::Kind::kCastFU:
     return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
-  case kCastSF:
+  case TensorExp::Kind::kCastSF:
     return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
-  case kCastUF:
+  case TensorExp::Kind::kCastUF:
     return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
-  case kCastS:
+  case TensorExp::Kind::kCastS:
     return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
-  case kCastU:
+  case TensorExp::Kind::kCastU:
     return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
-  case kCastIdx:
+  case TensorExp::Kind::kCastIdx:
     return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
-  case kTruncI:
+  case TensorExp::Kind::kTruncI:
     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
-  case kCIm: {
+  case TensorExp::Kind::kCIm: {
     auto type = v0.getType().cast<ComplexType>();
     auto eltType = type.getElementType().cast<FloatType>();
     return rewriter.create<complex::ImOp>(loc, eltType, v0);
   }
-  case kCRe: {
+  case TensorExp::Kind::kCRe: {
     auto type = v0.getType().cast<ComplexType>();
     auto eltType = type.getElementType().cast<FloatType>();
     return rewriter.create<complex::ReOp>(loc, eltType, v0);
   }
-  case kBitCast:
+  case TensorExp::Kind::kBitCast:
     return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
   // Binary operations.
-  case kMulF:
+  case TensorExp::Kind::kMulF:
     return rewriter.create<arith::MulFOp>(loc, v0, v1);
-  case kMulC:
+  case TensorExp::Kind::kMulC:
     return rewriter.create<complex::MulOp>(loc, v0, v1);
-  case kMulI:
+  case TensorExp::Kind::kMulI:
     return rewriter.create<arith::MulIOp>(loc, v0, v1);
-  case kDivF:
+  case TensorExp::Kind::kDivF:
     return rewriter.create<arith::DivFOp>(loc, v0, v1);
-  case kDivC:
+  case TensorExp::Kind::kDivC:
     return rewriter.create<complex::DivOp>(loc, v0, v1);
-  case kDivS:
+  case TensorExp::Kind::kDivS:
     return rewriter.create<arith::DivSIOp>(loc, v0, v1);
-  case kDivU:
+  case TensorExp::Kind::kDivU:
     return rewriter.create<arith::DivUIOp>(loc, v0, v1);
-  case kAddF:
+  case TensorExp::Kind::kAddF:
     return rewriter.create<arith::AddFOp>(loc, v0, v1);
-  case kAddC:
+  case TensorExp::Kind::kAddC:
     return rewriter.create<complex::AddOp>(loc, v0, v1);
-  case kAddI:
+  case TensorExp::Kind::kAddI:
     return rewriter.create<arith::AddIOp>(loc, v0, v1);
-  case kSubF:
+  case TensorExp::Kind::kSubF:
     return rewriter.create<arith::SubFOp>(loc, v0, v1);
-  case kSubC:
+  case TensorExp::Kind::kSubC:
     return rewriter.create<complex::SubOp>(loc, v0, v1);
-  case kSubI:
+  case TensorExp::Kind::kSubI:
     return rewriter.create<arith::SubIOp>(loc, v0, v1);
-  case kAndI:
+  case TensorExp::Kind::kAndI:
     return rewriter.create<arith::AndIOp>(loc, v0, v1);
-  case kOrI:
+  case TensorExp::Kind::kOrI:
     return rewriter.create<arith::OrIOp>(loc, v0, v1);
-  case kXorI:
+  case TensorExp::Kind::kXorI:
     return rewriter.create<arith::XOrIOp>(loc, v0, v1);
-  case kShrS:
+  case TensorExp::Kind::kShrS:
     return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
-  case kShrU:
+  case TensorExp::Kind::kShrU:
     return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
-  case kShlI:
+  case TensorExp::Kind::kShlI:
     return rewriter.create<arith::ShLIOp>(loc, v0, v1);
-  case kBinaryBranch: // semi-ring ops with custom logic.
+  case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
     return insertYieldOp(rewriter, loc,
                          *tensorExps[e].op->getBlock()->getParent(), {v0});
-  case kUnary:
+  case TensorExp::Kind::kUnary:
     return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
-  case kSelect:
+  case TensorExp::Kind::kSelect:
     return insertYieldOp(rewriter, loc,
                          cast<SelectOp>(tensorExps[e].op).getRegion(), {v0});
-  case kBinary:
+  case TensorExp::Kind::kBinary:
     return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
-  case kReduce: {
+  case TensorExp::Kind::kReduce: {
     ReduceOp redOp = cast<ReduceOp>(tensorExps[e].op);
     return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
   }

diff  --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 10d350f7c6b97..270b5836907e3 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -23,18 +23,18 @@ namespace {
 ///
 
 #define FOREVERY_BINOP(DO)                                                     \
-  DO(mulf, Kind::kMulF)                                                        \
-  DO(mulc, Kind::kMulC)                                                        \
-  DO(muli, Kind::kMulI)                                                        \
-  DO(addf, Kind::kAddF)                                                        \
-  DO(addc, Kind::kAddC)                                                        \
-  DO(addi, Kind::kAddI)                                                        \
-  DO(subf, Kind::kSubF)                                                        \
-  DO(subc, Kind::kSubC)                                                        \
-  DO(subi, Kind::kSubI)                                                        \
-  DO(andi, Kind::kAndI)                                                        \
-  DO(xori, Kind::kXorI)                                                        \
-  DO(ori, Kind::kOrI)
+  DO(mulf, TensorExp::Kind::kMulF)                                             \
+  DO(mulc, TensorExp::Kind::kMulC)                                             \
+  DO(muli, TensorExp::Kind::kMulI)                                             \
+  DO(addf, TensorExp::Kind::kAddF)                                             \
+  DO(addc, TensorExp::Kind::kAddC)                                             \
+  DO(addi, TensorExp::Kind::kAddI)                                             \
+  DO(subf, TensorExp::Kind::kSubF)                                             \
+  DO(subc, TensorExp::Kind::kSubC)                                             \
+  DO(subi, TensorExp::Kind::kSubI)                                             \
+  DO(andi, TensorExp::Kind::kAndI)                                             \
+  DO(xori, TensorExp::Kind::kXorI)                                             \
+  DO(ori, TensorExp::Kind::kOrI)
 
 // TODO: Disjunctive binary operations that need special handling are not
 // included, e.g., Division are not tested (for now) as it need a constant
@@ -82,7 +82,7 @@ namespace {
 
 /// Simple recursive data structure used to match expressions in Mergers.
 struct Pattern {
-  Kind kind;
+  TensorExp::Kind kind;
 
   /// Expressions representing tensors simply have a tensor number.
   unsigned tensorNum;
@@ -94,11 +94,12 @@ struct Pattern {
   /// Constructors.
   /// Rather than using these, please use the readable helper constructor
   /// functions below to make tests more readable.
-  Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
-  Pattern(Kind kind, const std::shared_ptr<Pattern> &e0,
+  Pattern(unsigned tensorNum)
+      : kind(TensorExp::Kind::kTensor), tensorNum(tensorNum) {}
+  Pattern(TensorExp::Kind kind, const std::shared_ptr<Pattern> &e0,
           const std::shared_ptr<Pattern> &e1)
       : kind(kind), e0(e0), e1(e1) {
-    assert(kind >= Kind::kMulF);
+    assert(kind >= TensorExp::Kind::kMulF);
     assert(e0 && e1);
   }
 };
@@ -134,7 +135,7 @@ class MergerTestBase : public ::testing::Test {
   ///
 
   unsigned tensor(unsigned tensor) {
-    return merger.addExp(Kind::kTensor, tensor);
+    return merger.addExp(TensorExp::Kind::kTensor, tensor);
   }
 
 #define IMPL_BINOP_EXPR(OP, KIND)                                              \
@@ -222,69 +223,69 @@ class MergerTestBase : public ::testing::Test {
       return false;
     switch (tensorExp.kind) {
     // Leaf.
-    case kTensor:
+    case TensorExp::Kind::kTensor:
       return tensorExp.tensor == pattern->tensorNum;
-    case kInvariant:
-    case kLoopVar:
+    case TensorExp::Kind::kInvariant:
+    case TensorExp::Kind::kLoopVar:
       llvm_unreachable("invariant not handled yet");
     // Unary operations.
-    case kAbsF:
-    case kAbsC:
-    case kAbsI:
-    case kCeilF:
-    case kFloorF:
-    case kSqrtF:
-    case kSqrtC:
-    case kExpm1F:
-    case kExpm1C:
-    case kLog1pF:
-    case kLog1pC:
-    case kSinF:
-    case kSinC:
-    case kTanhF:
-    case kTanhC:
-    case kNegF:
-    case kNegC:
-    case kNegI:
-    case kTruncF:
-    case kExtF:
-    case kCastFS:
-    case kCastFU:
-    case kCastSF:
-    case kCastUF:
-    case kCastS:
-    case kCastU:
-    case kCastIdx:
-    case kTruncI:
-    case kCIm:
-    case kCRe:
-    case kBitCast:
-    case kSelect:
-    case kBinaryBranch:
-    case kUnary:
+    case TensorExp::Kind::kAbsF:
+    case TensorExp::Kind::kAbsC:
+    case TensorExp::Kind::kAbsI:
+    case TensorExp::Kind::kCeilF:
+    case TensorExp::Kind::kFloorF:
+    case TensorExp::Kind::kSqrtF:
+    case TensorExp::Kind::kSqrtC:
+    case TensorExp::Kind::kExpm1F:
+    case TensorExp::Kind::kExpm1C:
+    case TensorExp::Kind::kLog1pF:
+    case TensorExp::Kind::kLog1pC:
+    case TensorExp::Kind::kSinF:
+    case TensorExp::Kind::kSinC:
+    case TensorExp::Kind::kTanhF:
+    case TensorExp::Kind::kTanhC:
+    case TensorExp::Kind::kNegF:
+    case TensorExp::Kind::kNegC:
+    case TensorExp::Kind::kNegI:
+    case TensorExp::Kind::kTruncF:
+    case TensorExp::Kind::kExtF:
+    case TensorExp::Kind::kCastFS:
+    case TensorExp::Kind::kCastFU:
+    case TensorExp::Kind::kCastSF:
+    case TensorExp::Kind::kCastUF:
+    case TensorExp::Kind::kCastS:
+    case TensorExp::Kind::kCastU:
+    case TensorExp::Kind::kCastIdx:
+    case TensorExp::Kind::kTruncI:
+    case TensorExp::Kind::kCIm:
+    case TensorExp::Kind::kCRe:
+    case TensorExp::Kind::kBitCast:
+    case TensorExp::Kind::kSelect:
+    case TensorExp::Kind::kBinaryBranch:
+    case TensorExp::Kind::kUnary:
       return compareExpression(tensorExp.children.e0, pattern->e0);
     // Binary operations.
-    case kMulF:
-    case kMulC:
-    case kMulI:
-    case kDivF:
-    case kDivC:
-    case kDivS:
-    case kDivU:
-    case kAddF:
-    case kAddC:
-    case kAddI:
-    case kSubF:
-    case kSubC:
-    case kSubI:
-    case kAndI:
-    case kOrI:
-    case kXorI:
-    case kShrS:
-    case kShrU:
-    case kShlI:
-    case kBinary:
-    case kReduce:
+    case TensorExp::Kind::kMulF:
+    case TensorExp::Kind::kMulC:
+    case TensorExp::Kind::kMulI:
+    case TensorExp::Kind::kDivF:
+    case TensorExp::Kind::kDivC:
+    case TensorExp::Kind::kDivS:
+    case TensorExp::Kind::kDivU:
+    case TensorExp::Kind::kAddF:
+    case TensorExp::Kind::kAddC:
+    case TensorExp::Kind::kAddI:
+    case TensorExp::Kind::kSubF:
+    case TensorExp::Kind::kSubC:
+    case TensorExp::Kind::kSubI:
+    case TensorExp::Kind::kAndI:
+    case TensorExp::Kind::kOrI:
+    case TensorExp::Kind::kXorI:
+    case TensorExp::Kind::kShrS:
+    case TensorExp::Kind::kShrU:
+    case TensorExp::Kind::kShlI:
+    case TensorExp::Kind::kBinary:
+    case TensorExp::Kind::kReduce:
       return compareExpression(tensorExp.children.e0, pattern->e0) &&
              compareExpression(tensorExp.children.e1, pattern->e1);
     }
@@ -312,15 +313,15 @@ class MergerTest3T1L : public MergerTestBase {
     EXPECT_TRUE(merger.getOutTensorID() == t2);
 
     // Tensor 0: sparse input vector.
-    merger.addExp(Kind::kTensor, t0, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
     merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
 
     // Tensor 1: sparse input vector.
-    merger.addExp(Kind::kTensor, t1, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
     merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed);
 
     // Tensor 2: dense output vector.
-    merger.addExp(Kind::kTensor, t2, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
     merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense);
   }
 };
@@ -337,19 +338,19 @@ class MergerTest4T1L : public MergerTestBase {
     EXPECT_TRUE(merger.getOutTensorID() == t3);
 
     // Tensor 0: sparse input vector.
-    merger.addExp(Kind::kTensor, t0, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
     merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
 
     // Tensor 1: sparse input vector.
-    merger.addExp(Kind::kTensor, t1, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
     merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed);
 
     // Tensor 2: sparse input vector
-    merger.addExp(Kind::kTensor, t2, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
     merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed);
 
     // Tensor 3: dense output vector
-    merger.addExp(Kind::kTensor, t3, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t3, -1u);
     merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense);
   }
 };
@@ -370,15 +371,15 @@ class MergerTest3T1LD : public MergerTestBase {
     EXPECT_TRUE(merger.getOutTensorID() == t2);
 
     // Tensor 0: sparse input vector.
-    merger.addExp(Kind::kTensor, t0, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
     merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed);
 
     // Tensor 1: dense input vector.
-    merger.addExp(Kind::kTensor, t1, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
     merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense);
 
     // Tensor 2: dense output vector.
-    merger.addExp(Kind::kTensor, t2, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
     merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense);
   }
 };
@@ -399,19 +400,19 @@ class MergerTest4T1LU : public MergerTestBase {
     EXPECT_TRUE(merger.getOutTensorID() == t3);
 
     // Tensor 0: undef input vector.
-    merger.addExp(Kind::kTensor, t0, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
     merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef);
 
     // Tensor 1: dense input vector.
-    merger.addExp(Kind::kTensor, t1, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
     merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense);
 
     // Tensor 2: undef input vector.
-    merger.addExp(Kind::kTensor, t2, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
     merger.setLevelAndType(t2, l0, 0, DimLevelType::Undef);
 
     // Tensor 3: dense output vector.
-    merger.addExp(Kind::kTensor, t3, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t3, -1u);
     merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense);
   }
 };
@@ -435,15 +436,15 @@ class MergerTest3T1LSo : public MergerTestBase {
     merger.setHasSparseOut(true);
 
     // Tensor 0: undef input vector.
-    merger.addExp(Kind::kTensor, t0, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t0, -1u);
     merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef);
 
     // Tensor 1: undef input vector.
-    merger.addExp(Kind::kTensor, t1, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t1, -1u);
     merger.setLevelAndType(t1, l0, 0, DimLevelType::Undef);
 
     // Tensor 2: sparse output vector.
-    merger.addExp(Kind::kTensor, t2, -1u);
+    merger.addExp(TensorExp::Kind::kTensor, t2, -1u);
     merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed);
   }
 };


        


More information about the Mlir-commits mailing list