[Mlir-commits] [mlir] c374ef2 - [mlir][sparse] Extend the operator new rewriter to handle isSymmetric flag.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 17 10:48:29 PST 2022
Author: bixia1
Date: 2022-11-17T10:48:24-08:00
New Revision: c374ef2eb7768b2994aff6a49b55f729df2fb9c8
URL: https://github.com/llvm/llvm-project/commit/c374ef2eb7768b2994aff6a49b55f729df2fb9c8
DIFF: https://github.com/llvm/llvm-project/commit/c374ef2eb7768b2994aff6a49b55f729df2fb9c8.diff
LOG: [mlir][sparse] Extend the operator new rewriter to handle isSymmetric flag.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D138214
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 29430d35f7108..5fa2b4ebbfd2c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -911,6 +911,16 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
Value nnz = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
{indexTp}, {reader}, EmitCInterface::Off)
.getResult(0);
+ Value symmetric;
+ // We assume only rank 2 tensors may have the isSymmetric flag set.
+ if (rank == 2) {
+ symmetric =
+ createFuncCall(rewriter, loc, "getSparseTensorReaderIsSymmetric",
+ {rewriter.getI1Type()}, {reader}, EmitCInterface::Off)
+ .getResult(0);
+ } else {
+ symmetric = Value();
+ }
Type eltTp = dstTp.getElementType();
Value value = genAllocaScalar(rewriter, loc, eltTp);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1,
@@ -929,8 +939,23 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
loc, indices, constantIndex(rewriter, loc, i)));
}
Value v = rewriter.create<memref::LoadOp>(loc, value);
- auto t = rewriter.create<InsertOp>(loc, v, forOp.getRegionIterArg(0),
- indicesArray);
+ Value t = rewriter.create<InsertOp>(loc, v, forOp.getRegionIterArg(0),
+ indicesArray);
+ if (symmetric) {
+ Value eq = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne, indicesArray[0], indicesArray[1]);
+ Value cond = rewriter.create<arith::AndIOp>(loc, symmetric, eq);
+ scf::IfOp ifOp =
+ rewriter.create<scf::IfOp>(loc, t.getType(), cond, /*else*/ true);
+ rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ rewriter.create<scf::YieldOp>(
+ loc, Value(rewriter.create<InsertOp>(
+ loc, v, t, ValueRange{indicesArray[1], indicesArray[0]})));
+ rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ rewriter.create<scf::YieldOp>(loc, t);
+ t = ifOp.getResult(0);
+ rewriter.setInsertionPointAfter(ifOp);
+ }
rewriter.create<scf::YieldOp>(loc, ArrayRef<Value>(t));
rewriter.setInsertionPointAfter(forOp);
// Link SSA chain.
diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
index 94c373bab9972..31b35d75733c1 100644
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" |\
+// RUN: mlir-opt %s -post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" | \
// RUN: FileCheck %s
#CSR = #sparse_tensor.encoding<{
@@ -17,6 +17,7 @@
// CHECK: %[[D1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
// CHECK: %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])
// CHECK: %[[N:.*]] = call @getSparseTensorReaderNNZ(%[[R]])
+// CHECK: %[[S:.*]] = call @getSparseTensorReaderIsSymmetric(%[[R]])
// CHECK: %[[VB:.*]] = memref.alloca()
// CHECK: %[[T2:.*]] = scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] iter_args(%[[A2:.*]] = %[[T]])
// CHECK: func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]], %[[VB]])
@@ -24,12 +25,19 @@
// CHECK: %[[E1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
// CHECK: %[[V:.*]] = memref.load %[[VB]][]
// CHECK: %[[T1:.*]] = sparse_tensor.insert %[[V]] into %[[A2]]{{\[}}%[[E0]], %[[E1]]]
-// CHECK: scf.yield %[[T1]]
+// CHECK: %[[NE:.*]] = arith.cmpi ne, %[[E0]], %[[E1]]
+// CHECK: %[[COND:.*]] = arith.andi %[[S]], %[[NE]]
+// CHECK: %[[T3:.*]] = scf.if %[[COND]]
+// CHECK: %[[T4:.*]] = sparse_tensor.insert %[[V]] into %[[T1]]{{\[}}%[[E1]], %[[E0]]]
+// CHECK: scf.yield %[[T4]]
+// CHECK: else
+// CHECK: scf.yield %[[T1]]
+// CHECK: scf.yield %[[T3]]
// CHECK: }
// CHECK: call @delSparseTensorReader(%[[R]])
-// CHECK: %[[T3:.*]] = sparse_tensor.load %[[T2]] hasInserts
-// CHECK: %[[R:.*]] = sparse_tensor.convert %[[T3]]
-// CHECK: bufferization.dealloc_tensor %[[T3]]
+// CHECK: %[[T5:.*]] = sparse_tensor.load %[[T2]] hasInserts
+// CHECK: %[[R:.*]] = sparse_tensor.convert %[[T5]]
+// CHECK: bufferization.dealloc_tensor %[[T5]]
// CHECK: return %[[R]]
func.func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
index c6b4e03590d40..13a3ae25c0f0d 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
@@ -1,9 +1,16 @@
-// RUN: mlir-opt %s --sparse-compiler | \
-// RUN: TENSOR0="%mlir_src_dir/test/Integration/data/test_symmetric.mtx" \
-// RUN: mlir-cpu-runner \
-// RUN: -e entry -entry-point-result=void \
-// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: TENSOR0="%mlir_src_dir/test/Integration/data/test_symmetric.mtx" \
+// DEFINE: mlir-cpu-runner \
+// DEFINE: -e entry -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = enable-runtime-library=false
+// RUN: %{command}
!Filename = !llvm.ptr<i8>
More information about the Mlir-commits
mailing list