[Mlir-commits] [mlir] e7df828 - [mlir][sparse] rewrite arith::SelectOp to semiring operations to sparsify it.
Peiming Liu
llvmlistbot at llvm.org
Wed Jun 21 14:22:22 PDT 2023
Author: Peiming Liu
Date: 2023-06-21T21:22:18Z
New Revision: e7df82816b6af3e2929a703718d9ef9dcd55b5f4
URL: https://github.com/llvm/llvm-project/commit/e7df82816b6af3e2929a703718d9ef9dcd55b5f4
DIFF: https://github.com/llvm/llvm-project/commit/e7df82816b6af3e2929a703718d9ef9dcd55b5f4.diff
LOG: [mlir][sparse] rewrite arith::SelectOp to semiring operations to sparsify it.
Reviewed By: aartbik, K-Wu
Differential Revision: https://reviews.llvm.org/D153397
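For illustration, here is a minimal sketch of what the new GenSemiRingSelect pattern does, paraphrased from the pattern's documentation and the new pre_rewriting.mlir test (the value names are placeholders, and %c0 stands for the constant zero introduced by the rewrite):

    // Before: the select reads both sparse operands directly and would
    // otherwise not sparsify.
    %sel = arith.select %cond, %sp1, %sp2 : f64

    // After: a semiring binary op whose regions reproduce the select
    // semantics for the overlap, left-only, and right-only cases.
    %sel = sparse_tensor.binary %sp1, %sp2 : f64, f64 to f64
      overlap = {
        ^bb0(%l: f64, %r: f64):
          %0 = arith.select %cond, %l, %r : f64
          sparse_tensor.yield %0 : f64
      }
      left = {
        ^bb0(%l: f64):
          %0 = arith.select %cond, %l, %c0 : f64
          sparse_tensor.yield %0 : f64
      }
      right = {
        ^bb0(%r: f64):
          %0 = arith.select %cond, %c0, %r : f64
          sparse_tensor.yield %0 : f64
      }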
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_semiring_select.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index bb9c15ea463da..ebbe88ee90294 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -42,10 +42,13 @@ static bool isZeroValue(Value val) {
}
// Helper to detect a sparse tensor type operand.
-static bool isSparseTensor(OpOperand *op) {
- auto enc = getSparseTensorEncoding(op->get().getType());
- return enc && llvm::is_contained(enc.getLvlTypes(), DimLevelType::Compressed);
+static bool isSparseTensor(Value v) {
+ auto enc = getSparseTensorEncoding(v.getType());
+ return enc && !llvm::all_of(enc.getLvlTypes(), [](auto dlt) {
+ return dlt == DimLevelType::Dense;
+ });
}
+static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
// Helper method to find zero/uninitialized allocation.
static bool isAlloc(OpOperand *op, bool isZero) {
@@ -387,6 +390,137 @@ struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
}
};
+/// Rewrites a sequence of operations for sparse tensor selections into
+/// semi-ring operations so that they can be compiled correctly by the sparse
+/// compiler. E.g., transforming the following sequence
+///
+/// %sel = arith.select %cond, %sp1, %sp2
+///
+/// to
+///
+/// %sel = binary %sp1, %sp2:
+/// both (%l, %r) {yield select %cond, %l, %r}
+/// left (%l) {yield select %cond, %l, 0}
+/// right (%r) {yield select %cond, 0, %r}
+///
+/// TODO: We require the tensor used for extracting the conditions to be dense
+/// in order to sparsify the code. To support a sparse condition tensor, we
+/// need a ternary operation.
+struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
+public:
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(GenericOp op,
+ PatternRewriter &rewriter) const override {
+ // Rejects non-sparse kernels.
+ if (!op.hasTensorSemantics() || !hasAnySparseOperand(op))
+ return failure();
+
+ Location loc = op.getLoc();
+ SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
+ for (Operation &inst : *op.getBody()) {
+ // Matches pattern.
+ auto matched = isRewritablePattern(op, &inst);
+ if (!matched.has_value())
+ continue;
+
+ rewriter.setInsertionPoint(&inst);
+ auto [c, t, f] = matched.value();
+ assert(t.getType() == f.getType());
+ auto selTp = t.getType();
+ auto c0 = constantZero(rewriter, loc, selTp);
+ auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
+ // Initializes all the blocks.
+ rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
+ {t.getLoc(), f.getLoc()});
+ rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
+ rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
+
+ for (auto *r : binOp.getRegions()) {
+ Block *b = &r->front();
+ rewriter.setInsertionPointToStart(b);
+
+ IRMapping irMap;
+ // Clones the cmp operations into the region to make the binary op
+ // admissible.
+ Value newC = c;
+ if (auto *def = c.getDefiningOp())
+ newC = rewriter.clone(*def, irMap)->getResult(0);
+
+ irMap.map(c, newC);
+ if (r == &binOp.getLeftRegion()) {
+ irMap.map(t, b->getArgument(0));
+ irMap.map(f, c0);
+ } else if (r == &binOp.getRightRegion()) {
+ irMap.map(t, c0);
+ irMap.map(f, b->getArgument(0));
+ } else {
+ irMap.map(t, b->getArgument(0));
+ irMap.map(f, b->getArgument(1));
+ }
+ auto y = rewriter.clone(inst, irMap)->getResult(0);
+ rewriter.create<sparse_tensor::YieldOp>(loc, y);
+ }
+
+ // We successfully rewrote an operation. We cannot do the replacement here
+ // because it would invalidate the iterator used by the current loop to
+ // traverse the instructions.
+ semiRings.emplace_back(&inst, binOp);
+ }
+
+ // Finalizes the replacement.
+ for (auto [sel, semi] : semiRings)
+ rewriter.replaceOp(sel, semi->getResults());
+
+ return success(!semiRings.empty());
+ }
+
+private:
+ static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
+ isRewritablePattern(GenericOp op, Operation *v) {
+ auto sel = dyn_cast<arith::SelectOp>(v);
+ if (!sel)
+ return std::nullopt;
+
+ auto tVal = sel.getTrueValue().dyn_cast<BlockArgument>();
+ auto fVal = sel.getFalseValue().dyn_cast<BlockArgument>();
+ // TODO: For simplicity, we only handle cases where both the true/false
+ // values are loaded directly from input tensors. We can probably admit more
+ // cases in theory.
+ if (!tVal || !fVal)
+ return std::nullopt;
+
+ // Helper lambda to determine whether the value is loaded from a dense input
+ // or is a loop invariant.
+ auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
+ if (auto bArg = v.dyn_cast<BlockArgument>();
+ bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
+ return true;
+ // If the value is defined outside the loop, it is a loop invariant.
+ return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
+ };
+
+ // If the condition value is loaded directly from a dense tensor or is a
+ // loop invariant, we can sparsify the kernel.
+ auto cond = sel.getCondition();
+ if (isValFromDenseInputOrInvariant(cond))
+ return std::make_tuple(cond, tVal, fVal);
+
+ Value cmpL, cmpR;
+ if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
+ matchers::m_Any(&cmpR))) ||
+ matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
+ matchers::m_Any(&cmpR)))) {
+ // TODO: we could recursively check whether all the leaf values are loaded
+ // from dense tensors or are loop invariants.
+ if (isValFromDenseInputOrInvariant(cmpL) ||
+ isValFromDenseInputOrInvariant(cmpR))
+ return std::make_tuple(cond, tVal, fVal);
+ }
+
+ return std::nullopt;
+ };
+};
+
/// Rewrites a sparse reduction that would not sparsify directly since
/// doing so would only iterate over the stored elements, ignoring the
/// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
@@ -1348,7 +1482,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
- GenSemiRingReduction>(patterns.getContext());
+ GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
}
void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 881e02ea0f91c..b767f61598ff7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1211,7 +1211,8 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
if (ee &&
(kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
kind == TensorExp::Kind::kBinaryBranch ||
- kind == TensorExp::Kind::kReduce || kind == TensorExp::Kind::kSelect)) {
+ kind == TensorExp::Kind::kReduce ||
+ kind == TensorExp::Kind::kSelect)) {
OpBuilder::InsertionGuard guard(rewriter);
ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
}
diff --git a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
index 8aed1d6d205bd..feda768ed29f7 100644
--- a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
@@ -8,11 +8,25 @@
lvlTypes = [ "compressed-nu", "singleton" ]
}>
+#DCSR = #sparse_tensor.encoding<{
+ lvlTypes = ["compressed", "compressed"]
+}>
+
#Slice = #sparse_tensor.encoding<{
lvlTypes = [ "compressed-nu", "singleton" ],
dimSlices = [ (?, 1, 1), (?, 3, 1) ]
}>
+#sel_trait = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // C (in)
+ affine_map<(i,j) -> (i,j)>, // L (in)
+ affine_map<(i,j) -> (i,j)>, // R (in)
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel"]
+}
+
// CHECK-LABEL: func @sparse_nop_cast(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>)
// CHECK: return %[[A]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
@@ -43,3 +57,46 @@ func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64
%0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
return %0 : tensor<1x3xi64, #SortedCOO>
}
+
+// CHECK-LABEL: func.func @sparse_select(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xi1>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>) -> tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[VAL_4:.*]] = bufferization.alloc_tensor() : tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-NEXT: %[[VAL_5:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+// CHECK-NEXT: ^bb0(%[[VAL_6:.*]]: i1, %[[VAL_7:.*]]: f64, %[[VAL_8:.*]]: f64, %[[VAL_9:.*]]: f64):
+// CHECK-NEXT: %[[VAL_10:.*]] = sparse_tensor.binary %[[VAL_7]], %[[VAL_8]] : f64, f64 to f64
+// CHECK-NEXT: overlap = {
+// CHECK-NEXT: ^bb0(%[[VAL_11:.*]]: f64, %[[VAL_12:.*]]: f64):
+// CHECK-NEXT: %[[VAL_13:.*]] = arith.select %[[VAL_6]], %[[VAL_11]], %[[VAL_12]] : f64
+// CHECK-NEXT: sparse_tensor.yield %[[VAL_13]] : f64
+// CHECK-NEXT: }
+// CHECK-NEXT: left = {
+// CHECK-NEXT: ^bb0(%[[VAL_14:.*]]: f64):
+// CHECK-NEXT: %[[VAL_15:.*]] = arith.select %[[VAL_6]], %[[VAL_14]], %[[VAL_3]] : f64
+// CHECK-NEXT: sparse_tensor.yield %[[VAL_15]] : f64
+// CHECK-NEXT: }
+// CHECK-NEXT: right = {
+// CHECK-NEXT: ^bb0(%[[VAL_16:.*]]: f64):
+// CHECK-NEXT: %[[VAL_17:.*]] = arith.select %[[VAL_6]], %[[VAL_3]], %[[VAL_16]] : f64
+// CHECK-NEXT: sparse_tensor.yield %[[VAL_17]] : f64
+// CHECK-NEXT: }
+// CHECK-NEXT: linalg.yield %[[VAL_10]] : f64
+// CHECK-NEXT: } -> tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-NEXT: return %[[VAL_18:.*]] : tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-NEXT: }
+func.func @sparse_select(%cond: tensor<4x4xi1>,
+ %arga: tensor<4x4xf64, #DCSR>,
+ %argb: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
+ %xv = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
+ %0 = linalg.generic #sel_trait
+ ins(%cond, %arga, %argb: tensor<4x4xi1>, tensor<4x4xf64, #DCSR>, tensor<4x4xf64, #DCSR>)
+ outs(%xv: tensor<4x4xf64, #DCSR>) {
+ ^bb(%c: i1, %a: f64, %b: f64, %x: f64):
+ %1 = arith.select %c, %a, %b : f64
+ linalg.yield %1 : f64
+ } -> tensor<4x4xf64, #DCSR>
+ return %0 : tensor<4x4xf64, #DCSR>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_semiring_select.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_semiring_select.mlir
new file mode 100644
index 0000000000000..7648e8d1f7d0f
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_semiring_select.mlir
@@ -0,0 +1,97 @@
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE: -e entry -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_c_runner_utils | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
+// RUN: %{compile} | %{run}
+
+// Do the same run, but now with direct IR generation and, if available, VLA
+// vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true vl=4 enable-arm-sve=%ENABLE_VLA"
+// REDEFINE: %{run} = %lli_host_or_aarch64_cmd \
+// REDEFINE: --entry-function=entry_lli \
+// REDEFINE: --extra-module=%S/Inputs/main_for_lli.ll \
+// REDEFINE: %VLA_ARCH_ATTR_OPTIONS \
+// REDEFINE: --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
+// REDEFINE: FileCheck %s
+// RUN: %{compile} | mlir-translate -mlir-to-llvmir | %{run}
+
+#DCSR = #sparse_tensor.encoding<{
+ lvlTypes = ["compressed", "compressed"]
+}>
+
+#sel_trait = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // C (in)
+ affine_map<(i,j) -> (i,j)>, // L (in)
+ affine_map<(i,j) -> (i,j)>, // R (in)
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel"]
+}
+
+module {
+ func.func @sparse_select(%cond: tensor<5x5xi1>,
+ %arga: tensor<5x5xf64, #DCSR>,
+ %argb: tensor<5x5xf64, #DCSR>) -> tensor<5x5xf64, #DCSR> {
+ %xv = bufferization.alloc_tensor() : tensor<5x5xf64, #DCSR>
+ %0 = linalg.generic #sel_trait
+ ins(%cond, %arga, %argb: tensor<5x5xi1>, tensor<5x5xf64, #DCSR>, tensor<5x5xf64, #DCSR>)
+ outs(%xv: tensor<5x5xf64, #DCSR>) {
+ ^bb(%c: i1, %a: f64, %b: f64, %x: f64):
+ %1 = arith.select %c, %a, %b : f64
+ linalg.yield %1 : f64
+ } -> tensor<5x5xf64, #DCSR>
+ return %0 : tensor<5x5xf64, #DCSR>
+ }
+
+ // Driver method to call and verify the select kernel.
+ func.func @entry() {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f64
+
+ %cond = arith.constant sparse<
+ [ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4] ],
+ [ 1, 1, 1, 1, 1 ]
+ > : tensor<5x5xi1>
+ %lhs = arith.constant sparse<
+ [ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4] ],
+ [ 0.1, 1.1, 2.1, 3.1, 4.1 ]
+ > : tensor<5x5xf64>
+ %rhs = arith.constant sparse<
+ [ [0, 1], [1, 2], [2, 3], [3, 4], [4, 4]],
+ [ 1.1, 2.2, 3.3, 4.4 , 5.5 ]
+ > : tensor<5x5xf64>
+
+ %sl = sparse_tensor.convert %lhs : tensor<5x5xf64> to tensor<5x5xf64, #DCSR>
+ %sr = sparse_tensor.convert %rhs : tensor<5x5xf64> to tensor<5x5xf64, #DCSR>
+
+ // Call the sparse select kernel.
+ %1 = call @sparse_select(%cond, %sl, %sr) : (tensor<5x5xi1>,
+ tensor<5x5xf64, #DCSR>,
+ tensor<5x5xf64, #DCSR>) -> tensor<5x5xf64, #DCSR>
+
+
+ // CHECK: ( ( 0.1, 1.1, 0, 0, 0 ),
+ // CHECK-SAME: ( 0, 1.1, 2.2, 0, 0 ),
+ // CHECK-SAME: ( 0, 0, 2.1, 3.3, 0 ),
+ // CHECK-SAME: ( 0, 0, 0, 3.1, 4.4 ),
+ // CHECK-SAME: ( 0, 0, 0, 0, 4.1 ) )
+ %r = sparse_tensor.convert %1 : tensor<5x5xf64, #DCSR> to tensor<5x5xf64>
+ %v2 = vector.transfer_read %r[%c0, %c0], %f0 : tensor<5x5xf64>, vector<5x5xf64>
+ vector.print %v2 : vector<5x5xf64>
+
+ // Release the resources.
+ bufferization.dealloc_tensor %sl: tensor<5x5xf64, #DCSR>
+ bufferization.dealloc_tensor %sr: tensor<5x5xf64, #DCSR>
+ bufferization.dealloc_tensor %1: tensor<5x5xf64, #DCSR>
+
+ return
+ }
+}
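As the TODO in GenSemiRingSelect notes, the condition must come from a dense tensor (or be loop invariant) for the rewrite to apply. A hypothetical kernel that the pattern currently leaves alone, not taken from this patch, would use a sparse condition tensor, since expressing it would require a ternary semiring operation:

    // Hypothetical, unsupported case: %scond is itself sparse (e.g. #DCSR),
    // so the current binary-based rewrite does not fire.
    %0 = linalg.generic #sel_trait
      ins(%scond, %arga, %argb : tensor<4x4xi1, #DCSR>,
                                 tensor<4x4xf64, #DCSR>,
                                 tensor<4x4xf64, #DCSR>)
      outs(%xv : tensor<4x4xf64, #DCSR>) {
      ^bb(%c: i1, %a: f64, %b: f64, %x: f64):
        %1 = arith.select %c, %a, %b : f64
        linalg.yield %1 : f64
    } -> tensor<4x4xf64, #DCSR>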