[Mlir-commits] [mlir] 2b6e433 - [mlir][sparse] add shift ops support

Aart Bik llvmlistbot at llvm.org
Thu Jul 15 09:43:28 PDT 2021


Author: Aart Bik
Date: 2021-07-15T09:43:12-07:00
New Revision: 2b6e433230ab9fa8a898261cd460a3f1a1bc91ec

URL: https://github.com/llvm/llvm-project/commit/2b6e433230ab9fa8a898261cd460a3f1a1bc91ec
DIFF: https://github.com/llvm/llvm-project/commit/2b6e433230ab9fa8a898261cd460a3f1a1bc91ec.diff

LOG: [mlir][sparse] add shift ops support

Arbitrary shifts have some complications, but shifts by an invariant
amount (viz. tensor index expressions occur only at the left-hand side)
can easily be handled with the conjunctive rule.

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D106002

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index aa1e52f25b343..9f6657dbf6fbf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -47,6 +47,9 @@ enum Kind {
   kAndI,
   kOrI,
   kXorI,
+  kShrS, // signed
+  kShrU, // unsigned
+  kShlI,
 };
 
 /// Children subexpressions of tensor operations.
@@ -215,7 +218,8 @@ class Merger {
                  Value v1);
 
 private:
-  bool maybeZero(unsigned e);
+  bool maybeZero(unsigned e) const;
+  bool isInvariant(unsigned e) const;
 
   /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
   Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v);

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index b254f03139dc2..65bbb8284b822 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -208,13 +208,16 @@ bool Merger::isConjunction(unsigned t, unsigned e) const {
   case kFloorF:
   case kNegF:
   case kNegI:
+  case Kind::kDivF: // note: x / c only
+  case Kind::kDivS:
+  case Kind::kDivU:
+  case Kind::kShrS: // note: x >> inv only
+  case Kind::kShrU:
+  case Kind::kShlI:
     return isConjunction(t, tensorExps[e].children.e0);
   case Kind::kMulF:
   case Kind::kMulI:
   case Kind::kAndI:
-  case Kind::kDivF: // note: x / c only
-  case Kind::kDivS:
-  case Kind::kDivU:
     return isConjunction(t, tensorExps[e].children.e0) ||
            isConjunction(t, tensorExps[e].children.e1);
   default:
@@ -228,9 +231,9 @@ bool Merger::isConjunction(unsigned t, unsigned e) const {
 // Print methods (for debugging).
 //
 
-static const char *kOpSymbols[] = {"",  "",  "abs", "ceil", "floor", "-",
-                                   "-", "*", "*",   "/",    "/",     "+",
-                                   "+", "-", "-",   "&",    "|",     "^"};
+static const char *kOpSymbols[] = {
+    "",  "",  "abs", "ceil", "floor", "-", "-", "*",   "*",  "/", "/",
+    "+", "+", "-",   "-",    "&",     "|", "^", "a>>", ">>", "<<"};
 
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
@@ -383,6 +386,15 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     return takeDisj(kind, // take binary disjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
+  case Kind::kShrS:
+  case Kind::kShrU:
+  case Kind::kShlI:
+    // A shift operation by an invariant amount (viz. tensor expressions
+    // can only occur at the left-hand-side of the operator) can be handled
+    // with the conjunction rule.
+    return takeConj(kind, // take binary conjunction
+                    buildLattices(tensorExps[e].children.e0, i),
+                    buildLattices(tensorExps[e].children.e1, i));
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -392,7 +404,7 @@ Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
   return buildTensorExp(op, yield->getOperand(0));
 }
 
-bool Merger::maybeZero(unsigned e) {
+bool Merger::maybeZero(unsigned e) const {
   if (tensorExps[e].kind == Kind::kInvariant) {
     if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
       return c.getValue() == 0;
@@ -402,6 +414,10 @@ bool Merger::maybeZero(unsigned e) {
   return true;
 }
 
+bool Merger::isInvariant(unsigned e) const {
+  return tensorExps[e].kind == Kind::kInvariant;
+}
+
 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   if (auto arg = v.dyn_cast<BlockArgument>()) {
     unsigned argN = arg.getArgNumber();
@@ -470,6 +486,12 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         return addExp(Kind::kOrI, e0, e1);
       if (isa<XOrOp>(def))
         return addExp(Kind::kXorI, e0, e1);
+      if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
+        return addExp(Kind::kShrS, e0, e1);
+      if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
+        return addExp(Kind::kShrU, e0, e1);
+      if (isa<ShiftLeftOp>(def) && isInvariant(e1))
+        return addExp(Kind::kShlI, e0, e1);
     }
   }
   // Cannot build.
@@ -517,6 +539,12 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
     return rewriter.create<OrOp>(loc, v0, v1);
   case Kind::kXorI:
     return rewriter.create<XOrOp>(loc, v0, v1);
+  case Kind::kShrS:
+    return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
+  case Kind::kShrU:
+    return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
+  case Kind::kShlI:
+    return rewriter.create<ShiftLeftOp>(loc, v0, v1);
   }
   llvm_unreachable("unexpected expression kind in build");
 }

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
index 11f82f9baf62f..8b5396252ed77 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
@@ -404,3 +404,106 @@ func @xor(%arga: tensor<32xi64, #SV>,
   } -> tensor<32xi64>
   return %0 : tensor<32xi64>
 }
+
+// CHECK-LABEL:   func @ashrbyc(
+// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK:           %[[VAL_2:.*]] = constant 2 : i64
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
+// CHECK:             %[[VAL_14:.*]] = shift_right_signed %[[VAL_13]], %[[VAL_2]] : i64
+// CHECK:             memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
+// CHECK:           }
+// CHECK:           %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
+// CHECK:           return %[[VAL_15]] : tensor<32xi64>
+// CHECK:         }
+func @ashrbyc(%arga: tensor<32xi64, #SV>,
+              %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %c = constant 2 : i64
+  %0 = linalg.generic #traitc
+     ins(%arga: tensor<32xi64, #SV>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %x: i64):
+        %0 = shift_right_signed %a, %c : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+
+// CHECK-LABEL:   func @lsrbyc(
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK:           %[[VAL_2:.*]] = constant 2 : i64
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
+// CHECK:             %[[VAL_14:.*]] = shift_right_unsigned %[[VAL_13]], %[[VAL_2]] : i64
+// CHECK:             memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
+// CHECK:           }
+// CHECK:           %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
+// CHECK:           return %[[VAL_15]] : tensor<32xi64>
+// CHECK:         }
+func @lsrbyc(%arga: tensor<32xi64, #SV>,
+             %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %c = constant 2 : i64
+  %0 = linalg.generic #traitc
+     ins(%arga: tensor<32xi64, #SV>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %x: i64):
+        %0 = shift_right_unsigned %a, %c : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+
+// CHECK-LABEL:   func @lslbyc(
+// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK:           %[[VAL_2:.*]] = constant 2 : i64
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
+// CHECK:             %[[VAL_14:.*]] = shift_left %[[VAL_13]], %[[VAL_2]] : i64
+// CHECK:             memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
+// CHECK:           }
+// CHECK:           %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
+// CHECK:           return %[[VAL_15]] : tensor<32xi64>
+// CHECK:         }
+func @lslbyc(%arga: tensor<32xi64, #SV>,
+             %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %c = constant 2 : i64
+  %0 = linalg.generic #traitc
+     ins(%arga: tensor<32xi64, #SV>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %x: i64):
+        %0 = shift_left %a, %c : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+


        


More information about the Mlir-commits mailing list