[Mlir-commits] [mlir] 45b3cfe - [mlir][sparse] add support for AND and OR operations

Aart Bik llvmlistbot at llvm.org
Mon Jul 12 17:47:33 PDT 2021


Author: Aart Bik
Date: 2021-07-12T17:47:18-07:00
New Revision: 45b3cfe8437f78faac2c84b796bb246e16382252

URL: https://github.com/llvm/llvm-project/commit/45b3cfe8437f78faac2c84b796bb246e16382252
DIFF: https://github.com/llvm/llvm-project/commit/45b3cfe8437f78faac2c84b796bb246e16382252.diff

LOG: [mlir][sparse] add support for AND and OR operations

Integral AND and OR follow the simple conjunction and disjunction rules
for lattice building. This revision also completes some of the Merger
refactoring by moving the remaining merger-specific parts from
sparsification into the utils files.

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D105851

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 109ad5c5d34ce..d81cb37b03f9a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -38,7 +38,9 @@ enum class Kind {
   kAddF,
   kAddI,
   kSubF,
-  kSubI
+  kSubI,
+  kAndI,
+  kOrI,
 };
 
 /// Children subexpressions of tensor operations.
@@ -171,6 +173,11 @@ class Merger {
   /// Returns true if any set bit corresponds to queried dim.
   bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
 
+  /// Returns true if given tensor co-iterates with conjunction only in the
+  /// given tensor expression. For the output tensor, this defines a "simply
+  /// dynamic" operation [Bik96]. For instance: a(i) *=  b(i) * c(i)
+  bool isConjunction(unsigned t, unsigned e) const;
+
   /// Dimension setter.
   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
 
@@ -193,17 +200,21 @@ class Merger {
   /// Builds the iteration lattices in a bottom-up traversal given the remaining
   /// tensor (sub)expression and the next loop index in the iteration graph.
   /// Returns index of the root expression.
-  unsigned buildLattices(unsigned exp, unsigned idx);
+  unsigned buildLattices(unsigned e, unsigned i);
 
   /// Builds a tensor expression from the given Linalg operation.
   /// Returns index of the root expression on success.
   Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);
 
+  /// Rebuilds SSA format from a tensor expression.
+  Value buildExp(PatternRewriter &rewriter, Location loc, unsigned e, Value v0,
+                 Value v1);
+
 private:
   bool maybeZero(unsigned e);
 
   /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
-  Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value val);
+  Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v);
 
   const unsigned outTensor;
   const unsigned syntheticTensor;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index dc252f2d0a8b2..9d5447e181c87 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -208,22 +208,6 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
   return true;
 }
 
-/// Returns true if given tensor co-iterates with conjunction only.
-/// For the output tensor, this defines a "simply dynamic" operation.
-/// For instance: A(I) = A(I) * B(I) * C(I)
-static unsigned isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
-  switch (merger.exp(exp).kind) {
-  case Kind::kTensor:
-    return merger.exp(exp).tensor == tensor;
-  case Kind::kMulF:
-  case Kind::kMulI:
-    return isConjunction(merger, tensor, merger.exp(exp).children.e0) ||
-           isConjunction(merger, tensor, merger.exp(exp).children.e1);
-  default:
-    return false;
-  }
-}
-
 /// Returns true when the tensor expression is admissable for codegen.
 /// Since all sparse input tensors are admissable, we just need to check
 /// whether the output tensor in the tensor expression codegen is admissable.
@@ -250,7 +234,7 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
   // A tensor expression with a sparse output tensor that changes its values
   // but not its nonzero structure, an operation called "simply dynamic" in
   // [Bik96,Ch9], is also admissable without special codegen.
-  if (isConjunction(merger, tensor, exp))
+  if (merger.isConjunction(tensor, exp))
     return true;
   // Reject for now since this requires changes to the nonzero structure.
   // TODO: implement "workspaces" [Kjolstad2019]
@@ -637,31 +621,7 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   }
   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
-  switch (merger.exp(exp).kind) {
-  case Kind::kTensor:
-  case Kind::kInvariant:
-  case Kind::kZero:
-    llvm_unreachable("handled above");
-  case Kind::kMulF:
-    return rewriter.create<MulFOp>(loc, v0, v1);
-  case Kind::kMulI:
-    return rewriter.create<MulIOp>(loc, v0, v1);
-  case Kind::kDivF:
-    return rewriter.create<DivFOp>(loc, v0, v1);
-  case Kind::kDivS:
-    return rewriter.create<SignedDivIOp>(loc, v0, v1);
-  case Kind::kDivU:
-    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
-  case Kind::kAddF:
-    return rewriter.create<AddFOp>(loc, v0, v1);
-  case Kind::kAddI:
-    return rewriter.create<AddIOp>(loc, v0, v1);
-  case Kind::kSubF:
-    return rewriter.create<SubFOp>(loc, v0, v1);
-  case Kind::kSubI:
-    return rewriter.create<SubIOp>(loc, v0, v1);
-  }
-  llvm_unreachable("unexpected expression kind");
+  return merger.buildExp(rewriter, loc, exp, v0, v1);
 }
 
 /// Hoists loop invariant tensor loads for which indices have been exhausted.

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 4ec748f426efa..74a3a24738eaf 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -190,6 +190,23 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
   return false;
 }
 
+bool Merger::isConjunction(unsigned t, unsigned e) const { // t: tensor, e: exp
+  switch (tensorExps[e].kind) {
+  case Kind::kTensor: // leaf: matches only when it is tensor t itself
+    return tensorExps[e].tensor == t;
+  case Kind::kMulF: // kinds that buildLattices combines with takeConj
+  case Kind::kMulI:
+  case Kind::kAndI:
+  case Kind::kDivF: // note: x / c only
+  case Kind::kDivS:
+  case Kind::kDivU:
+    return isConjunction(t, tensorExps[e].children.e0) || // t occurring in
+           isConjunction(t, tensorExps[e].children.e1);   // either operand
+  default: // disjunctive kinds (add/sub/or) and non-tensor leaves
+    return false;
+  }
+}
+
 #ifndef NDEBUG
 
 //
@@ -211,6 +228,10 @@ static char kindToOpSymbol(Kind kind) {
   case Kind::kSubF:
   case Kind::kSubI:
     return '-';
+  case Kind::kAndI:
+    return '&';
+  case Kind::kOrI:
+    return '|';
   default:
     break;
   }
@@ -290,7 +311,7 @@ void Merger::dumpBits(const llvm::BitVector &bits) const {
 // Builder methods.
 //
 
-unsigned Merger::buildLattices(unsigned e, unsigned idx) {
+unsigned Merger::buildLattices(unsigned e, unsigned i) {
   Kind kind = tensorExps[e].kind;
   switch (kind) {
   case Kind::kTensor:
@@ -301,11 +322,12 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) {
     // is set to a synthetic tensor with undefined indices only.
     unsigned s = addSet();
     unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
-    latSets[s].push_back(addLat(t, idx, e));
+    latSets[s].push_back(addLat(t, i, e));
     return s;
   }
   case Kind::kMulF:
   case Kind::kMulI:
+  case Kind::kAndI:
     // A multiplicative operation only needs to be performed
     // for the conjunction of sparse iteration spaces.
     //
@@ -314,8 +336,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) {
     //  !x | 0 | 0 |
     //   x | 0 |x*y|
     return takeConj(kind, // take binary conjunction
-                    buildLattices(tensorExps[e].children.e0, idx),
-                    buildLattices(tensorExps[e].children.e1, idx));
+                    buildLattices(tensorExps[e].children.e0, i),
+                    buildLattices(tensorExps[e].children.e1, i));
   case Kind::kDivF:
   case Kind::kDivS:
   case Kind::kDivU:
@@ -333,17 +355,18 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) {
     //       rules applies (viz. x/c = x*(1/c) as far as lattice
     //       construction is concerned).
     return takeConj(kind, // take binary conjunction
-                    buildLattices(tensorExps[e].children.e0, idx),
-                    buildLattices(tensorExps[e].children.e1, idx));
+                    buildLattices(tensorExps[e].children.e0, i),
+                    buildLattices(tensorExps[e].children.e1, i));
   case Kind::kSubF:
   case Kind::kSubI:
     // Special case: 0-y is -y.
     if (tensorExps[tensorExps[e].children.e0].kind == Kind::kZero)
       return mapZero(kind, // maps to 0-y with just y's lattices
-                     buildLattices(tensorExps[e].children.e1, idx));
+                     buildLattices(tensorExps[e].children.e1, i));
     LLVM_FALLTHROUGH;
   case Kind::kAddF:
   case Kind::kAddI:
+  case Kind::kOrI:
     // An additive operation needs to be performed
     // for the disjunction of sparse iteration spaces.
     //
@@ -352,8 +375,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) {
     //  !x | 0 | y |    !x | 0 |-y |
     //   x | x |x+y|     x | x |x-y|
     return takeDisj(kind, // take binary disjunction
-                    buildLattices(tensorExps[e].children.e0, idx),
-                    buildLattices(tensorExps[e].children.e1, idx));
+                    buildLattices(tensorExps[e].children.e0, i),
+                    buildLattices(tensorExps[e].children.e1, i));
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -373,8 +396,8 @@ bool Merger::maybeZero(unsigned e) {
   return true;
 }
 
-Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
-  if (auto arg = val.dyn_cast<BlockArgument>()) {
+Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
+  if (auto arg = v.dyn_cast<BlockArgument>()) {
     unsigned argN = arg.getArgNumber();
     // Any argument of the generic op that is not marked as a scalar
     // argument is considered a tensor, indexed by the implicit loop
@@ -383,16 +406,16 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
       OpOperand *t = op.getInputAndOutputOperands()[argN];
       if (!op.isScalar(t))
         return addExp(Kind::kTensor, argN);
-      val = t->get(); // get scalar value
+      v = t->get(); // get scalar value
     }
     // Any other argument (marked as scalar argument for the generic op
     // or belonging to an enveloping op) is considered invariant.
-    return addExp(Kind::kInvariant, val);
+    return addExp(Kind::kInvariant, v);
   }
   // Something defined outside is invariant.
-  Operation *def = val.getDefiningOp();
+  Operation *def = v.getDefiningOp();
   if (def->getBlock() != &op.region().front())
-    return addExp(Kind::kInvariant, val);
+    return addExp(Kind::kInvariant, v);
   // Construct unary operations if subexpression can be built.
   if (def->getNumOperands() == 1) {
     auto x = buildTensorExp(op, def->getOperand(0));
@@ -430,11 +453,48 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
         return addExp(Kind::kSubF, e0, e1);
       if (isa<SubIOp>(def))
         return addExp(Kind::kSubI, e0, e1);
+      if (isa<AndOp>(def))
+        return addExp(Kind::kAndI, e0, e1);
+      if (isa<OrOp>(def))
+        return addExp(Kind::kOrI, e0, e1);
     }
   }
   // Cannot build.
   return None;
 }
 
+Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e, // rebuilds op for exp e
+                       Value v0, Value v1) { // v0/v1: already-built operand values
+  switch (tensorExps[e].kind) { // emit the one binary op matching e's kind
+  case Kind::kTensor: // leaves are not ops; the caller (genExp) handles them
+  case Kind::kInvariant:
+  case Kind::kZero:
+    llvm_unreachable("unexpected non-op");
+  case Kind::kMulF:
+    return rewriter.create<MulFOp>(loc, v0, v1);
+  case Kind::kMulI:
+    return rewriter.create<MulIOp>(loc, v0, v1);
+  case Kind::kDivF:
+    return rewriter.create<DivFOp>(loc, v0, v1);
+  case Kind::kDivS: // signed integer division
+    return rewriter.create<SignedDivIOp>(loc, v0, v1);
+  case Kind::kDivU: // unsigned integer division
+    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
+  case Kind::kAddF:
+    return rewriter.create<AddFOp>(loc, v0, v1);
+  case Kind::kAddI:
+    return rewriter.create<AddIOp>(loc, v0, v1);
+  case Kind::kSubF:
+    return rewriter.create<SubFOp>(loc, v0, v1);
+  case Kind::kSubI:
+    return rewriter.create<SubIOp>(loc, v0, v1);
+  case Kind::kAndI: // mirrors the AndOp -> kAndI mapping in buildTensorExp
+    return rewriter.create<AndOp>(loc, v0, v1);
+  case Kind::kOrI: // mirrors the OrOp -> kOrI mapping in buildTensorExp
+    return rewriter.create<OrOp>(loc, v0, v1);
+  }
+  llvm_unreachable("unexpected expression kind in build"); // every case returns above
+}
+
 } // namespace sparse_tensor
 } // namespace mlir

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
index 3161ee51d0217..84d2557c2b29d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
@@ -248,3 +248,100 @@ func @divubyc(%arga: tensor<32xi64, #SV>,
   } -> tensor<32xi64>
   return %0 : tensor<32xi64>
 }
+
+// CHECK-LABEL:   func @and(
+// CHECK-SAME:              %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:              %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME:              %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xi64>
+// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<32xi64>
+// CHECK:             %[[VAL_16:.*]] = and %[[VAL_14]], %[[VAL_15]] : i64
+// CHECK:             memref.store %[[VAL_16]], %[[VAL_9]]{{\[}}%[[VAL_13]]] : memref<32xi64>
+// CHECK:           }
+// CHECK:           %[[VAL_17:.*]] = memref.tensor_load %[[VAL_9]] : memref<32xi64>
+// CHECK:           return %[[VAL_17]] : tensor<32xi64>
+// CHECK:         }
+func @and(%arga: tensor<32xi64, #SV>,
+          %argb: tensor<32xi64>,
+          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %b: i64, %x: i64):
+        %0 = and %a, %b : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+
+// CHECK-LABEL:   func @or(
+// CHECK-SAME:             %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:             %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME:             %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK:           %[[VAL_3:.*]] = constant 32 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant true
+// CHECK:           %[[VAL_6:.*]] = constant 1 : index
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_17:.*]] = cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
+// CHECK:             scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
+// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:             %[[VAL_21:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             scf.if %[[VAL_21]] {
+// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xi64>
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:               %[[VAL_24:.*]] = or %[[VAL_22]], %[[VAL_23]] : i64
+// CHECK:               memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_5]] {
+// CHECK:                 %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:                 memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_26:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_27:.*]] = addi %[[VAL_18]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index
+// CHECK:             %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK:             scf.yield %[[VAL_28]], %[[VAL_29]] : index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_30:.*]] = %[[VAL_31:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK:             %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK:             memref.store %[[VAL_32]], %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK:           }
+// CHECK:           %[[VAL_33:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xi64>
+// CHECK:           return %[[VAL_33]] : tensor<32xi64>
+// CHECK:         }
+func @or(%arga: tensor<32xi64, #SV>,
+         %argb: tensor<32xi64>,
+         %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %b: i64, %x: i64):
+        %0 = or %a, %b : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+


        


More information about the Mlir-commits mailing list