[Mlir-commits] [mlir] 123e8df - [mlir][sparse] add support for std unary operations

Aart Bik llvmlistbot at llvm.org
Tue Jul 13 14:51:28 PDT 2021


Author: Aart Bik
Date: 2021-07-13T14:51:13-07:00
New Revision: 123e8dfcf86a74eb7ba08f33681df581d1be9dbd

URL: https://github.com/llvm/llvm-project/commit/123e8dfcf86a74eb7ba08f33681df581d1be9dbd
DIFF: https://github.com/llvm/llvm-project/commit/123e8dfcf86a74eb7ba08f33681df581d1be9dbd.diff

LOG: [mlir][sparse] add support for std unary operations

Adds zero-preserving unary operators from std. Also adds xor.
Performs minor refactoring to remove the "zero" node, and pushes
the irregular logic for negi (not supported in std) into one place.

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D105928

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
    mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index d81cb37b03f9a..aa1e52f25b343 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -21,15 +21,20 @@ namespace mlir {
 namespace sparse_tensor {
 
 /// Dimension level type for a tensor (undef means index does not appear).
-enum class Dim { kSparse, kDense, kSingle, kUndef };
+enum Dim { kSparse, kDense, kSingle, kUndef };
 
 /// Tensor expression kind.
-enum class Kind {
+enum Kind {
   // Leaf.
-  kTensor,
+  kTensor = 0,
   kInvariant,
-  kZero,
-  // Operation.
+  // Unary operations.
+  kAbsF,
+  kCeilF,
+  kFloorF,
+  kNegF,
+  kNegI,
+  // Binary operations.
   kMulF,
   kMulI,
   kDivF,
@@ -41,6 +46,7 @@ enum class Kind {
   kSubI,
   kAndI,
   kOrI,
+  kXorI,
 };
 
 /// Children subexpressions of tensor operations.
@@ -105,8 +111,7 @@ class Merger {
         dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
 
   /// Adds a tensor expression. Returns its index.
-  unsigned addExp(Kind k, unsigned e0 = -1u, unsigned e1 = -1u,
-                  Value v = Value());
+  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
   unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
 
   /// Adds an iteration lattice point. Returns its index.
@@ -129,11 +134,10 @@ class Merger {
   /// Returns the index of the new set.
   unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);
 
-  /// Maps a zero operand over a lattice set, i.e. each lattice point on an
-  /// expression E is simply copied over, but with 0 OP E as new expression.
-  /// This is useful to deal with disjunctive, but non-commutative operators.
-  /// Returns the index of the new set.
-  unsigned mapZero(Kind kind, unsigned s0);
+  /// Maps the unary operator over the lattice set of the operand, i.e. each
+  /// lattice point on an expression E is simply copied over, but with OP E
+  /// as new expression. Returns the index of the new set.
+  unsigned mapSet(Kind kind, unsigned s0);
 
   /// Optimizes the iteration lattice points in the given set. This
   /// method should be called right before code generation to avoid

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 9d5447e181c87..a852a3df512d6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -609,18 +609,22 @@ static void genReductionEnd(Merger &merger, CodeGen &codegen,
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
   Location loc = op.getLoc();
+  if (exp == -1u)
+    return Value();
   if (merger.exp(exp).kind == Kind::kTensor)
     return genTensorLoad(merger, codegen, rewriter, op, exp);
   if (merger.exp(exp).kind == Kind::kInvariant)
     return genInvariantValue(merger, codegen, rewriter, exp);
-  if (merger.exp(exp).kind == Kind::kZero) {
-    Type tp = op.getOutputTensorTypes()[0].getElementType();
-    merger.exp(exp).val =
-        rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
-    return genInvariantValue(merger, codegen, rewriter, exp);
-  }
   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
+  if (merger.exp(exp).kind == Kind::kNegI) {
+    // TODO: no negi in std, need to make zero explicit.
+    Type tp = op.getOutputTensorTypes()[0].getElementType();
+    v1 = v0;
+    v0 = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
+    if (codegen.curVecLength > 1)
+      v0 = genVectorInvariantValue(codegen, rewriter, v0);
+  }
   return merger.buildExp(rewriter, loc, exp, v0, v1);
 }
 
@@ -628,6 +632,8 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
 static void genInvariants(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp, unsigned ldx, bool hoist) {
+  if (exp == -1u)
+    return;
   if (merger.exp(exp).kind == Kind::kTensor) {
     // Inspect tensor indices.
     bool atLevel = ldx == -1u;
@@ -649,8 +655,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
       merger.exp(exp).val =
           hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
     }
-  } else if (merger.exp(exp).kind != Kind::kInvariant &&
-             merger.exp(exp).kind != Kind::kZero) {
+  } else if (merger.exp(exp).kind != Kind::kInvariant) {
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 74a3a24738eaf..b254f03139dc2 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -28,8 +28,14 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
   case Kind::kInvariant:
     assert(x == -1u && y == -1u && v);
     break;
-  case Kind::kZero:
-    assert(x == -1u && y == -1u && !v);
+  case kAbsF:
+  case kCeilF:
+  case kFloorF:
+  case kNegF:
+  case kNegI:
+    assert(x != -1u && y == -1u && !v);
+    children.e0 = x;
+    children.e1 = y;
     break;
   default:
     assert(x != -1u && y != -1u && !v);
@@ -89,22 +95,25 @@ unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
 
 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
   unsigned s = takeConj(kind, s0, s1);
-  // Followed by all in s0 and s1.
+  // Followed by all in s0.
   for (unsigned p : latSets[s0])
     latSets[s].push_back(p);
-  if (Kind::kSubF <= kind && kind <= Kind::kSubI)
-    s1 = mapZero(kind, s1);
+  // Map binary 0-y to unary -y.
+  if (kind == Kind::kSubF)
+    s1 = mapSet(Kind::kNegF, s1);
+  else if (kind == Kind::kSubI)
+    s1 = mapSet(Kind::kNegI, s1);
+  // Followed by all in s1.
   for (unsigned p : latSets[s1])
     latSets[s].push_back(p);
   return s;
 }
 
-unsigned Merger::mapZero(Kind kind, unsigned s0) {
-  assert(Kind::kSubF <= kind && kind <= Kind::kSubI);
+unsigned Merger::mapSet(Kind kind, unsigned s0) {
+  assert(Kind::kAbsF <= kind && kind <= Kind::kNegI);
   unsigned s = addSet();
-  unsigned z = addExp(Kind::kZero);
   for (unsigned p : latSets[s0]) {
-    unsigned e = addExp(kind, z, latPoints[p].exp);
+    unsigned e = addExp(kind, latPoints[p].exp);
     latPoints.push_back(LatPoint(latPoints[p].bits, e));
     latSets[s].push_back(latPoints.size() - 1);
   }
@@ -194,6 +203,12 @@ bool Merger::isConjunction(unsigned t, unsigned e) const {
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
     return tensorExps[e].tensor == t;
+  case kAbsF:
+  case kCeilF:
+  case kFloorF:
+  case kNegF:
+  case kNegI:
+    return isConjunction(t, tensorExps[e].children.e0);
   case Kind::kMulF:
   case Kind::kMulI:
   case Kind::kAndI:
@@ -213,30 +228,9 @@ bool Merger::isConjunction(unsigned t, unsigned e) const {
 // Print methods (for debugging).
 //
 
-static char kindToOpSymbol(Kind kind) {
-  switch (kind) {
-  case Kind::kMulF:
-  case Kind::kMulI:
-    return '*';
-  case Kind::kDivF:
-  case Kind::kDivS:
-  case Kind::kDivU:
-    return '/';
-  case Kind::kAddF:
-  case Kind::kAddI:
-    return '+';
-  case Kind::kSubF:
-  case Kind::kSubI:
-    return '-';
-  case Kind::kAndI:
-    return '&';
-  case Kind::kOrI:
-    return '|';
-  default:
-    break;
-  }
-  llvm_unreachable("unexpected kind");
-}
+static const char *kOpSymbols[] = {"",  "",  "abs", "ceil", "floor", "-", "-",
+                                   "*", "*", "/",   "/",    "/",     "+", "+",
+                                   "-", "-", "&",   "|",    "^"};
 
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
@@ -250,13 +244,18 @@ void Merger::dumpExp(unsigned e) const {
   case Kind::kInvariant:
     llvm::dbgs() << "invariant";
     break;
-  case Kind::kZero:
-    llvm::dbgs() << "zero";
+  case kAbsF:
+  case kCeilF:
+  case kFloorF:
+  case kNegF:
+  case kNegI:
+    llvm::dbgs() << kOpSymbols[tensorExps[e].kind] << " ";
+    dumpExp(tensorExps[e].children.e0);
     break;
   default:
     llvm::dbgs() << "(";
     dumpExp(tensorExps[e].children.e0);
-    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
+    llvm::dbgs() << " " << kOpSymbols[tensorExps[e].kind] << " ";
     dumpExp(tensorExps[e].children.e1);
     llvm::dbgs() << ")";
   }
@@ -315,8 +314,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
   Kind kind = tensorExps[e].kind;
   switch (kind) {
   case Kind::kTensor:
-  case Kind::kInvariant:
-  case Kind::kZero: {
+  case Kind::kInvariant: {
     // Either the index is really used in the tensor expression, or it is
     // set to the undefined index in that dimension. An invariant expression
     // is set to a synthetic tensor with undefined indices only.
@@ -325,6 +323,18 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     latSets[s].push_back(addLat(t, i, e));
     return s;
   }
+  case kAbsF:
+  case kCeilF:
+  case kFloorF:
+  case kNegF:
+  case kNegI:
+    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
+    // lattice set of the operand through the operator into a new set.
+    //
+    //  -y|!y | y |
+    //  --+---+---+
+    //    | 0 |-y |
+    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i));
   case Kind::kMulF:
   case Kind::kMulI:
   case Kind::kAndI:
@@ -357,16 +367,12 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     return takeConj(kind, // take binary conjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
-  case Kind::kSubF:
-  case Kind::kSubI:
-    // Special case: 0-y is -y.
-    if (tensorExps[tensorExps[e].children.e0].kind == Kind::kZero)
-      return mapZero(kind, // maps to 0-y with just y's lattices
-                     buildLattices(tensorExps[e].children.e1, i));
-    LLVM_FALLTHROUGH;
   case Kind::kAddF:
   case Kind::kAddI:
+  case Kind::kSubF:
+  case Kind::kSubI:
   case Kind::kOrI:
+  case Kind::kXorI:
     // An additive operation needs to be performed
     // for the disjunction of sparse iteration spaces.
     //
@@ -420,10 +426,15 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   if (def->getNumOperands() == 1) {
     auto x = buildTensorExp(op, def->getOperand(0));
     if (x.hasValue()) {
-      unsigned e0 = addExp(Kind::kZero);
-      unsigned e1 = x.getValue();
+      unsigned e = x.getValue();
+      if (isa<AbsFOp>(def))
+        return addExp(Kind::kAbsF, e);
+      if (isa<CeilFOp>(def))
+        return addExp(Kind::kCeilF, e);
+      if (isa<FloorFOp>(def))
+        return addExp(Kind::kFloorF, e);
       if (isa<NegFOp>(def))
-        return addExp(Kind::kSubF, e0, e1);
+        return addExp(Kind::kNegF, e);
       // TODO: no negi in std?
     }
   }
@@ -457,6 +468,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         return addExp(Kind::kAndI, e0, e1);
       if (isa<OrOp>(def))
         return addExp(Kind::kOrI, e0, e1);
+      if (isa<XOrOp>(def))
+        return addExp(Kind::kXorI, e0, e1);
     }
   }
   // Cannot build.
@@ -468,8 +481,18 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
   case Kind::kInvariant:
-  case Kind::kZero:
     llvm_unreachable("unexpected non-op");
+  case kAbsF:
+    return rewriter.create<AbsFOp>(loc, v0);
+  case kCeilF:
+    return rewriter.create<CeilFOp>(loc, v0);
+  case kFloorF:
+    return rewriter.create<FloorFOp>(loc, v0);
+  case kNegF:
+    return rewriter.create<NegFOp>(loc, v0);
+  case kNegI:
+    assert(v1); // no negi in std
+    return rewriter.create<SubIOp>(loc, v0, v1);
   case Kind::kMulF:
     return rewriter.create<MulFOp>(loc, v0, v1);
   case Kind::kMulI:
@@ -492,6 +515,8 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
     return rewriter.create<AndOp>(loc, v0, v1);
   case Kind::kOrI:
     return rewriter.create<OrOp>(loc, v0, v1);
+  case Kind::kXorI:
+    return rewriter.create<XOrOp>(loc, v0, v1);
   }
   llvm_unreachable("unexpected expression kind in build");
 }

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
index ff43aa7e48fa0..c7c6507517c1a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
@@ -31,26 +31,120 @@
   doc = "x(i) = a(i) OP c"
 }
 
+// CHECK-LABEL:   func @abs(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK:           %[[VAL_2:.*]] = constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf64>
+// CHECK:             %[[VAL_13:.*]] = absf %[[VAL_12]] : f64
+// CHECK:             memref.store %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_14:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf64>
+// CHECK:           return %[[VAL_14]] : tensor<32xf64>
+func @abs(%arga: tensor<32xf64, #SV>,
+          %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf64, #SV>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %x: f64):
+        %0 = absf %a : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
+// CHECK-LABEL:   func @ceil(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK:           %[[VAL_2:.*]] = constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf64>
+// CHECK:             %[[VAL_13:.*]] = ceilf %[[VAL_12]] : f64
+// CHECK:             memref.store %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_14:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf64>
+// CHECK:           return %[[VAL_14]] : tensor<32xf64>
+// CHECK:         }
+func @ceil(%arga: tensor<32xf64, #SV>,
+           %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf64, #SV>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %x: f64):
+        %0 = ceilf %a : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
+// CHECK-LABEL:   func @floor(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK:           %[[VAL_2:.*]] = constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf64>
+// CHECK:             %[[VAL_13:.*]] = floorf %[[VAL_12]] : f64
+// CHECK:             memref.store %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_14:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf64>
+// CHECK:           return %[[VAL_14]] : tensor<32xf64>
+// CHECK:         }
+func @floor(%arga: tensor<32xf64, #SV>,
+            %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf64, #SV>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %x: f64):
+        %0 = floorf %a : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
 // CHECK-LABEL:   func @neg(
 // CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
 // CHECK:           %[[VAL_2:.*]] = constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = constant 1 : index
-// CHECK:           %[[VAL_4:.*]] = constant 0.000000e+00 : f64
-// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
-// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_3]] {
-// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xf64>
-// CHECK:             %[[VAL_14:.*]] = subf %[[VAL_4]], %[[VAL_13]] : f64
-// CHECK:             memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf64>
+// CHECK:             %[[VAL_13:.*]] = negf %[[VAL_12]] : f64
+// CHECK:             memref.store %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf64>
 // CHECK:           }
-// CHECK:           %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xf64>
-// CHECK:           return %[[VAL_15]] : tensor<32xf64>
+// CHECK:           %[[VAL_14:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf64>
+// CHECK:           return %[[VAL_14]] : tensor<32xf64>
 // CHECK:         }
 func @neg(%arga: tensor<32xf64, #SV>,
           %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
@@ -132,47 +226,46 @@ func @add(%arga: tensor<32xf64, #SV>,
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = constant true
 // CHECK:           %[[VAL_6:.*]] = constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = constant 0.000000e+00 : f64
-// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
-// CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
-// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
-// CHECK:             %[[VAL_18:.*]] = cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
-// CHECK:             scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_17:.*]] = cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
+// CHECK:             scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
 // CHECK:           } do {
-// CHECK:           ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
-// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
-// CHECK:             %[[VAL_22:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
-// CHECK:             scf.if %[[VAL_22]] {
-// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xf64>
-// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xf64>
-// CHECK:               %[[VAL_25:.*]] = subf %[[VAL_23]], %[[VAL_24]] : f64
-// CHECK:               memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK:           ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
+// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:             %[[VAL_21:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             scf.if %[[VAL_21]] {
+// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK:               %[[VAL_24:.*]] = subf %[[VAL_22]], %[[VAL_23]] : f64
+// CHECK:               memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf64>
 // CHECK:             } else {
 // CHECK:               scf.if %[[VAL_5]] {
-// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xf64>
-// CHECK:                 %[[VAL_27:.*]] = subf %[[VAL_7]], %[[VAL_26]] : f64
-// CHECK:                 memref.store %[[VAL_27]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK:                 %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK:                 %[[VAL_26:.*]] = negf %[[VAL_25]] : f64
+// CHECK:                 memref.store %[[VAL_26]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf64>
 // CHECK:               } else {
 // CHECK:               }
 // CHECK:             }
-// CHECK:             %[[VAL_28:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
-// CHECK:             %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
-// CHECK:             %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_19]] : index
-// CHECK:             %[[VAL_31:.*]] = addi %[[VAL_20]], %[[VAL_6]] : index
-// CHECK:             scf.yield %[[VAL_30]], %[[VAL_31]] : index, index
+// CHECK:             %[[VAL_27:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_28:.*]] = addi %[[VAL_18]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_29:.*]] = select %[[VAL_27]], %[[VAL_28]], %[[VAL_18]] : index
+// CHECK:             %[[VAL_30:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK:             scf.yield %[[VAL_29]], %[[VAL_30]] : index, index
 // CHECK:           }
-// CHECK:           scf.for %[[VAL_32:.*]] = %[[VAL_33:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
-// CHECK:             %[[VAL_34:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_32]]] : memref<32xf64>
-// CHECK:             %[[VAL_35:.*]] = subf %[[VAL_7]], %[[VAL_34]] : f64
-// CHECK:             memref.store %[[VAL_35]], %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<32xf64>
+// CHECK:           scf.for %[[VAL_31:.*]] = %[[VAL_32:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK:             %[[VAL_33:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<32xf64>
+// CHECK:             %[[VAL_34:.*]] = negf %[[VAL_33]] : f64
+// CHECK:             memref.store %[[VAL_34]], %[[VAL_11]]{{\[}}%[[VAL_31]]] : memref<32xf64>
 // CHECK:           }
-// CHECK:           %[[VAL_36:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xf64>
-// CHECK:           return %[[VAL_36]] : tensor<32xf64>
+// CHECK:           %[[VAL_35:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
+// CHECK:           return %[[VAL_35]] : tensor<32xf64>
 // CHECK:         }
 func @sub(%arga: tensor<32xf64, #SV>,
           %argb: tensor<32xf64>,

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
index 84d2557c2b29d..11f82f9baf62f 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
@@ -345,3 +345,62 @@ func @or(%arga: tensor<32xi64, #SV>,
   return %0 : tensor<32xi64>
 }
 
+// CHECK-LABEL:   func @xor(
+// CHECK-SAME:             %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:             %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME:             %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK:           %[[VAL_3:.*]] = constant 32 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant true
+// CHECK:           %[[VAL_6:.*]] = constant 1 : index
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_17:.*]] = cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
+// CHECK:             scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
+// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:             %[[VAL_21:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             scf.if %[[VAL_21]] {
+// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xi64>
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:               %[[VAL_24:.*]] = xor %[[VAL_22]], %[[VAL_23]] : i64
+// CHECK:               memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_5]] {
+// CHECK:                 %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:                 memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_26:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_27:.*]] = addi %[[VAL_18]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index
+// CHECK:             %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK:             scf.yield %[[VAL_28]], %[[VAL_29]] : index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_30:.*]] = %[[VAL_31:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK:             %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK:             memref.store %[[VAL_32]], %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK:           }
+// CHECK:           %[[VAL_33:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xi64>
+// CHECK:           return %[[VAL_33]] : tensor<32xi64>
+// CHECK:         }
+func @xor(%arga: tensor<32xi64, #SV>,
+          %argb: tensor<32xi64>,
+          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %b: i64, %x: i64):
+        %0 = xor %a, %b : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}

diff  --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index b544ce16469ae..1cb4d085b23b7 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -145,14 +145,24 @@ class MergerTestBase : public ::testing::Test {
     switch (tensorExp.kind) {
     case Kind::kTensor:
       return tensorExp.tensor == pattern->tensorNum;
-    case Kind::kZero:
-      return true;
+    case Kind::kAbsF:
+    case Kind::kCeilF:
+    case Kind::kFloorF:
+    case Kind::kNegF:
+    case Kind::kNegI:
+      return compareExpression(tensorExp.children.e0, pattern->e0);
     case Kind::kMulF:
     case Kind::kMulI:
+    case Kind::kDivF:
+    case Kind::kDivS:
+    case Kind::kDivU:
     case Kind::kAddF:
     case Kind::kAddI:
     case Kind::kSubF:
     case Kind::kSubI:
+    case Kind::kAndI:
+    case Kind::kOrI:
+    case Kind::kXorI:
       return compareExpression(tensorExp.children.e0, pattern->e0) &&
              compareExpression(tensorExp.children.e1, pattern->e1);
     default:


        


More information about the Mlir-commits mailing list