[Mlir-commits] [mlir] e7df828 - [mlir][sparse] rewrite arith::SelectOp to semiring operations to sparsify it.

Peiming Liu llvmlistbot at llvm.org
Wed Jun 21 14:22:22 PDT 2023


Author: Peiming Liu
Date: 2023-06-21T21:22:18Z
New Revision: e7df82816b6af3e2929a703718d9ef9dcd55b5f4

URL: https://github.com/llvm/llvm-project/commit/e7df82816b6af3e2929a703718d9ef9dcd55b5f4
DIFF: https://github.com/llvm/llvm-project/commit/e7df82816b6af3e2929a703718d9ef9dcd55b5f4.diff

LOG: [mlir][sparse] rewrite arith::SelectOp to semiring operations to sparsify it.

Reviewed By: aartbik, K-Wu

Differential Revision: https://reviews.llvm.org/D153397
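
For illustration, the pattern rewrites an arith.select inside a linalg.generic
into a sparse_tensor.binary whose regions reapply the select. A minimal sketch
on f64 values, mirroring the pre_rewriting.mlir test below (the names %a, %b,
%c, %f0 are illustrative; %c is the i1 condition loaded from the dense tensor
and %f0 a zero constant):

    %0 = sparse_tensor.binary %a, %b : f64, f64 to f64
           overlap = {
           ^bb0(%l: f64, %r: f64):
             %s = arith.select %c, %l, %r : f64
             sparse_tensor.yield %s : f64
           }
           left = {
           ^bb0(%l: f64):
             %s = arith.select %c, %l, %f0 : f64
             sparse_tensor.yield %s : f64
           }
           right = {
           ^bb0(%r: f64):
             %s = arith.select %c, %f0, %r : f64
             sparse_tensor.yield %s : f64
           }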

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_semiring_select.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/pre_rewriting.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index bb9c15ea463da..ebbe88ee90294 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -42,10 +42,13 @@ static bool isZeroValue(Value val) {
 }
 
 // Helper to detect a sparse tensor type operand.
-static bool isSparseTensor(OpOperand *op) {
-  auto enc = getSparseTensorEncoding(op->get().getType());
-  return enc && llvm::is_contained(enc.getLvlTypes(), DimLevelType::Compressed);
+static bool isSparseTensor(Value v) {
+  auto enc = getSparseTensorEncoding(v.getType());
+  return enc && !llvm::all_of(enc.getLvlTypes(), [](auto dlt) {
+           return dlt == DimLevelType::Dense;
+         });
 }
+static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
 
 // Helper method to find zero/uninitialized allocation.
 static bool isAlloc(OpOperand *op, bool isZero) {
@@ -387,6 +390,137 @@ struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
   }
 };
 
+/// Rewrites a sequence of operations for sparse tensor selections into
+/// semi-ring operations such that they can be compiled correctly by the sparse
+/// compiler. E.g., transforming the following sequence
+///
+/// %sel = arith.select %cond, %sp1, %sp2
+///
+/// to
+///
+/// %sel = binary %sp1, %sp2:
+///         both  (%l, %r) {yield select %cond, %l, %r}
+///         left  (%l)     {yield select %cond, %l,  0}
+///         right (%r)     {yield select %cond,  0, %r}
+///
+/// TODO: We require the tensor used for extracting conditions to be dense
+/// in order to sparsify the code. To support a sparse condition tensor, we
+/// would need a ternary operation.
+struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
+public:
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    // Rejects non-sparse kernels.
+    if (!op.hasTensorSemantics() || !hasAnySparseOperand(op))
+      return failure();
+
+    Location loc = op.getLoc();
+    SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
+    for (Operation &inst : *op.getBody()) {
+      // Matches pattern.
+      auto matched = isRewritablePattern(op, &inst);
+      if (!matched.has_value())
+        continue;
+
+      rewriter.setInsertionPoint(&inst);
+      auto [c, t, f] = matched.value();
+      assert(t.getType() == f.getType());
+      auto selTp = t.getType();
+      auto c0 = constantZero(rewriter, loc, selTp);
+      auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
+      // Initializes all the blocks.
+      rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
+                           {t.getLoc(), f.getLoc()});
+      rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
+      rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
+
+      for (auto *r : binOp.getRegions()) {
+        Block *b = &r->front();
+        rewriter.setInsertionPointToStart(b);
+
+        IRMapping irMap;
+        // Clones the cmp operations into the region to make the binary op
+        // admissible.
+        Value newC = c;
+        if (auto *def = c.getDefiningOp())
+          newC = rewriter.clone(*def, irMap)->getResult(0);
+
+        irMap.map(c, newC);
+        if (r == &binOp.getLeftRegion()) {
+          irMap.map(t, b->getArgument(0));
+          irMap.map(f, c0);
+        } else if (r == &binOp.getRightRegion()) {
+          irMap.map(t, c0);
+          irMap.map(f, b->getArgument(0));
+        } else {
+          irMap.map(t, b->getArgument(0));
+          irMap.map(f, b->getArgument(1));
+        }
+        auto y = rewriter.clone(inst, irMap)->getResult(0);
+        rewriter.create<sparse_tensor::YieldOp>(loc, y);
+      }
+
+      // We successfully rewrote an operation. We cannot do the replacement
+      // here because it would invalidate the iterator used by the current
+      // loop to traverse the instructions.
+      semiRings.emplace_back(&inst, binOp);
+    }
+
+    // Finalizes the replacement.
+    for (auto [sel, semi] : semiRings)
+      rewriter.replaceOp(sel, semi->getResults());
+
+    return success(!semiRings.empty());
+  }
+
+private:
+  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
+  isRewritablePattern(GenericOp op, Operation *v) {
+    auto sel = dyn_cast<arith::SelectOp>(v);
+    if (!sel)
+      return std::nullopt;
+
+    auto tVal = sel.getTrueValue().dyn_cast<BlockArgument>();
+    auto fVal = sel.getFalseValue().dyn_cast<BlockArgument>();
+    // TODO: For simplicity, we only handle cases where both the true/false
+    // values are loaded directly from the input tensors. We could probably
+    // admit more cases in theory.
+    if (!tVal || !fVal)
+      return std::nullopt;
+
+    // Helper lambda to determine whether the value is loaded from a dense input
+    // or is a loop invariant.
+    auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
+      if (auto bArg = v.dyn_cast<BlockArgument>();
+          bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
+        return true;
+      // If the value is defined outside the loop, it is a loop invariant.
+      return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
+    };
+
+    // If the condition value is loaded directly from a dense tensor or is
+    // loop-invariant, we can sparsify the kernel.
+    auto cond = sel.getCondition();
+    if (isValFromDenseInputOrInvariant(cond))
+      return std::make_tuple(cond, tVal, fVal);
+
+    Value cmpL, cmpR;
+    if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
+                                               matchers::m_Any(&cmpR))) ||
+        matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
+                                               matchers::m_Any(&cmpR)))) {
+      // TODO: we could check recursively whether all the leaf values are
+      // loaded from dense tensors or are loop invariants.
+      if (isValFromDenseInputOrInvariant(cmpL) ||
+          isValFromDenseInputOrInvariant(cmpR))
+        return std::make_tuple(cond, tVal, fVal);
+    }
+
+    return std::nullopt;
+  };
+};
+
 /// Rewrites a sparse reduction that would not sparsify directly since
 /// doing so would only iterate over the stored elements, ignoring the
 /// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
@@ -1348,7 +1482,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
   patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
-               GenSemiRingReduction>(patterns.getContext());
+               GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
 }
 
 void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 881e02ea0f91c..b767f61598ff7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1211,7 +1211,8 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
     if (ee &&
         (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
          kind == TensorExp::Kind::kBinaryBranch ||
-         kind == TensorExp::Kind::kReduce || kind == TensorExp::Kind::kSelect)) {
+         kind == TensorExp::Kind::kReduce ||
+         kind == TensorExp::Kind::kSelect)) {
       OpBuilder::InsertionGuard guard(rewriter);
       ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
     }

diff --git a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
index 8aed1d6d205bd..feda768ed29f7 100644
--- a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
@@ -8,11 +8,25 @@
   lvlTypes = [ "compressed-nu", "singleton" ]
 }>
 
+#DCSR = #sparse_tensor.encoding<{
+  lvlTypes = ["compressed", "compressed"]
+}>
+
 #Slice = #sparse_tensor.encoding<{
   lvlTypes = [ "compressed-nu", "singleton" ],
   dimSlices = [ (?, 1, 1), (?, 3, 1) ]
 }>
 
+#sel_trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // C (in)
+    affine_map<(i,j) -> (i,j)>,  // L (in)
+    affine_map<(i,j) -> (i,j)>,  // R (in)
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
 // CHECK-LABEL: func @sparse_nop_cast(
 //  CHECK-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>)
 //       CHECK: return %[[A]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
@@ -43,3 +57,46 @@ func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64
   %0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
   return %0 : tensor<1x3xi64, #SortedCOO>
 }
+
+// CHECK-LABEL:   func.func @sparse_select(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x4xi1>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>) -> tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG:       %[[VAL_4:.*]] = bufferization.alloc_tensor() : tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-NEXT:      %[[VAL_5:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME:      ins(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+// CHECK-NEXT:      ^bb0(%[[VAL_6:.*]]: i1, %[[VAL_7:.*]]: f64, %[[VAL_8:.*]]: f64, %[[VAL_9:.*]]: f64):
+// CHECK-NEXT:        %[[VAL_10:.*]] = sparse_tensor.binary %[[VAL_7]], %[[VAL_8]] : f64, f64 to f64
+// CHECK-NEXT:         overlap = {
+// CHECK-NEXT:        ^bb0(%[[VAL_11:.*]]: f64, %[[VAL_12:.*]]: f64):
+// CHECK-NEXT:          %[[VAL_13:.*]] = arith.select %[[VAL_6]], %[[VAL_11]], %[[VAL_12]] : f64
+// CHECK-NEXT:          sparse_tensor.yield %[[VAL_13]] : f64
+// CHECK-NEXT:        }
+// CHECK-NEXT:         left = {
+// CHECK-NEXT:        ^bb0(%[[VAL_14:.*]]: f64):
+// CHECK-NEXT:          %[[VAL_15:.*]] = arith.select %[[VAL_6]], %[[VAL_14]], %[[VAL_3]] : f64
+// CHECK-NEXT:          sparse_tensor.yield %[[VAL_15]] : f64
+// CHECK-NEXT:        }
+// CHECK-NEXT:         right = {
+// CHECK-NEXT:        ^bb0(%[[VAL_16:.*]]: f64):
+// CHECK-NEXT:          %[[VAL_17:.*]] = arith.select %[[VAL_6]], %[[VAL_3]], %[[VAL_16]] : f64
+// CHECK-NEXT:          sparse_tensor.yield %[[VAL_17]] : f64
+// CHECK-NEXT:        }
+// CHECK-NEXT:        linalg.yield %[[VAL_10]] : f64
+// CHECK-NEXT:      } -> tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-NEXT:      return %[[VAL_18:.*]] : tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-NEXT:    }
+func.func @sparse_select(%cond: tensor<4x4xi1>,
+                         %arga: tensor<4x4xf64, #DCSR>,
+                         %argb: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
+  %xv = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
+  %0 = linalg.generic #sel_trait
+     ins(%cond, %arga, %argb: tensor<4x4xi1>, tensor<4x4xf64, #DCSR>, tensor<4x4xf64, #DCSR>)
+      outs(%xv: tensor<4x4xf64, #DCSR>) {
+      ^bb(%c: i1, %a: f64, %b: f64, %x: f64):
+        %1 = arith.select %c, %a, %b : f64
+        linalg.yield %1 : f64
+  } -> tensor<4x4xf64, #DCSR>
+  return %0 : tensor<4x4xf64, #DCSR>
+}

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_semiring_select.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_semiring_select.mlir
new file mode 100644
index 0000000000000..7648e8d1f7d0f
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_semiring_select.mlir
@@ -0,0 +1,97 @@
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void  \
+// DEFINE:  -shared-libs=%mlir_c_runner_utils | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
+// RUN: %{compile} | %{run}
+
+// Do the same run, but now with direct IR generation and, if available, VLA
+// vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=4 enable-arm-sve=%ENABLE_VLA"
+// REDEFINE: %{run} = %lli_host_or_aarch64_cmd \
+// REDEFINE:   --entry-function=entry_lli \
+// REDEFINE:   --extra-module=%S/Inputs/main_for_lli.ll \
+// REDEFINE:   %VLA_ARCH_ATTR_OPTIONS \
+// REDEFINE:   --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
+// REDEFINE: FileCheck %s
+// RUN: %{compile} | mlir-translate -mlir-to-llvmir | %{run}
+
+#DCSR = #sparse_tensor.encoding<{
+  lvlTypes = ["compressed", "compressed"]
+}>
+
+#sel_trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // C (in)
+    affine_map<(i,j) -> (i,j)>,  // L (in)
+    affine_map<(i,j) -> (i,j)>,  // R (in)
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
+module {
+  func.func @sparse_select(%cond: tensor<5x5xi1>,
+                           %arga: tensor<5x5xf64, #DCSR>,
+                           %argb: tensor<5x5xf64, #DCSR>) -> tensor<5x5xf64, #DCSR> {
+    %xv = bufferization.alloc_tensor() : tensor<5x5xf64, #DCSR>
+    %0 = linalg.generic #sel_trait
+       ins(%cond, %arga, %argb: tensor<5x5xi1>, tensor<5x5xf64, #DCSR>, tensor<5x5xf64, #DCSR>)
+        outs(%xv: tensor<5x5xf64, #DCSR>) {
+        ^bb(%c: i1, %a: f64, %b: f64, %x: f64):
+          %1 = arith.select %c, %a, %b : f64
+          linalg.yield %1 : f64
+    } -> tensor<5x5xf64, #DCSR>
+    return %0 : tensor<5x5xf64, #DCSR>
+  }
+
+  // Driver method to call and verify the select kernel.
+  func.func @entry() {
+    %c0 = arith.constant 0   : index
+    %f0 = arith.constant 0.0 : f64
+
+    %cond = arith.constant sparse<
+        [ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4] ],
+        [     1,      1,      1,      1,      1  ]
+    > : tensor<5x5xi1>
+    %lhs = arith.constant sparse<
+        [ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4] ],
+        [   0.1,    1.1,    2.1,    3.1,    4.1  ]
+    > : tensor<5x5xf64>
+    %rhs = arith.constant sparse<
+        [ [0, 1], [1, 2], [2, 3], [3, 4], [4, 4]],
+        [   1.1,    2.2,    3.3,    4.4 ,   5.5 ]
+    > : tensor<5x5xf64>
+
+    %sl = sparse_tensor.convert %lhs : tensor<5x5xf64> to tensor<5x5xf64, #DCSR>
+    %sr = sparse_tensor.convert %rhs : tensor<5x5xf64> to tensor<5x5xf64, #DCSR>
+
+    // Call the sparse matrix kernel.
+    %1 = call @sparse_select(%cond, %sl, %sr) : (tensor<5x5xi1>,
+                                                 tensor<5x5xf64, #DCSR>,
+                                                 tensor<5x5xf64, #DCSR>) -> tensor<5x5xf64, #DCSR>
+
+
+    // CHECK:     ( ( 0.1, 1.1, 0, 0, 0 ),
+    // CHECK-SAME:  ( 0, 1.1, 2.2, 0, 0 ),
+    // CHECK-SAME:  ( 0, 0, 2.1, 3.3, 0 ),
+    // CHECK-SAME:  ( 0, 0, 0, 3.1, 4.4 ),
+    // CHECK-SAME:  ( 0, 0, 0, 0, 4.1 ) )
+    %r = sparse_tensor.convert %1 : tensor<5x5xf64, #DCSR> to tensor<5x5xf64>
+    %v2 = vector.transfer_read %r[%c0, %c0], %f0 : tensor<5x5xf64>, vector<5x5xf64>
+    vector.print %v2 : vector<5x5xf64>
+
+    // Release the resources.
+    bufferization.dealloc_tensor %sl: tensor<5x5xf64, #DCSR>
+    bufferization.dealloc_tensor %sr: tensor<5x5xf64, #DCSR>
+    bufferization.dealloc_tensor %1:  tensor<5x5xf64, #DCSR>
+
+    return
+  }
+}
