[Mlir-commits] [mlir] 69edacb - [mlir][sparse] Add support for complex.im and complex.re to the sparse compiler.
Bixia Zheng
llvmlistbot at llvm.org
Wed May 18 08:53:11 PDT 2022
Author: Bixia Zheng
Date: 2022-05-18T15:53:07Z
New Revision: 69edacbcf0c232de6213297cb600b0f0313c6397
URL: https://github.com/llvm/llvm-project/commit/69edacbcf0c232de6213297cb600b0f0313c6397
DIFF: https://github.com/llvm/llvm-project/commit/69edacbcf0c232de6213297cb600b0f0313c6397.diff
LOG: [mlir][sparse] Add support for complex.im and complex.re to the sparse compiler.
Add a test.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D125834
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index d4aafb74093e..ceee28466449 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -50,6 +50,8 @@ enum Kind {
kCastU, // unsigned
kCastIdx,
kTruncI,
+ kCIm, // complex.im
+ kCRe, // complex.re
kBitCast,
kBinaryBranch, // semiring unary branch created from a binary op
kUnary, // semiring unary op
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index c5cd2e4f4cb6..9a589965a528 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -46,6 +46,8 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
case kTanhF:
case kNegF:
case kNegI:
+ case kCIm:
+ case kCRe:
assert(x != -1u && y == -1u && !v && !o);
children.e0 = x;
children.e1 = y;
@@ -291,6 +293,8 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kCastU:
case kCastIdx:
case kTruncI:
+ case kCIm:
+ case kCRe:
case kBitCast:
return isSingleCondition(t, tensorExps[e].children.e0);
case kDivF: // note: x / c only
@@ -367,6 +371,10 @@ static const char *kindToOpSymbol(Kind kind) {
case kCastU:
case kCastIdx:
case kTruncI:
case kBitCast:
return "cast";
+ case kCIm: // not a cast: must not share the "cast" fall-through group above
+ return "complex.im";
+ case kCRe:
+ return "complex.re";
case kBinaryBranch:
@@ -526,6 +534,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
}
case kAbsF:
case kCeilF:
+ case kCIm:
+ case kCRe:
case kFloorF:
case kSqrtF:
case kExpm1F:
@@ -776,6 +786,10 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kCastIdx, e, v);
if (isa<arith::TruncIOp>(def))
return addExp(kTruncI, e, v);
+ if (isa<complex::ImOp>(def))
+ return addExp(kCIm, e);
+ if (isa<complex::ReOp>(def))
+ return addExp(kCRe, e);
if (isa<arith::BitcastOp>(def))
return addExp(kBitCast, e, v);
if (isa<sparse_tensor::UnaryOp>(def))
@@ -930,6 +944,15 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
case kTruncI:
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
+ case kCIm:
+ case kCRe: {
+ auto type = v0.getType().template cast<ComplexType>();
+ auto eltType = type.getElementType().template cast<FloatType>();
+ if (tensorExps[e].kind == kCIm)
+ return rewriter.create<complex::ImOp>(loc, eltType, v0);
+
+ return rewriter.create<complex::ReOp>(loc, eltType, v0);
+ }
case kBitCast:
return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
// Binary ops.
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir
new file mode 100644
index 000000000000..2656eb555165
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+#trait_op = {
+ indexing_maps = [
+ affine_map<(i) -> (i)>, // a (in)
+ affine_map<(i) -> (i)> // x (out)
+ ],
+ iterator_types = ["parallel"],
+ doc = "x(i) = OP a(i)"
+}
+
+module {
+ func.func @cre(%arga: tensor<?xcomplex<f32>, #SparseVector>)
+ -> tensor<?xf32, #SparseVector> { // x(i) = complex.re a(i)
+ %c = arith.constant 0 : index
+ %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
+ %xv = sparse_tensor.init [%d] : tensor<?xf32, #SparseVector>
+ %0 = linalg.generic #trait_op
+ ins(%arga: tensor<?xcomplex<f32>, #SparseVector>)
+ outs(%xv: tensor<?xf32, #SparseVector>) {
+ ^bb(%a: complex<f32>, %x: f32):
+ %1 = complex.re %a : complex<f32>
+ linalg.yield %1 : f32
+ } -> tensor<?xf32, #SparseVector>
+ return %0 : tensor<?xf32, #SparseVector>
+ }
+
+ func.func @cim(%arga: tensor<?xcomplex<f32>, #SparseVector>)
+ -> tensor<?xf32, #SparseVector> { // x(i) = complex.im a(i)
+ %c = arith.constant 0 : index
+ %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
+ %xv = sparse_tensor.init [%d] : tensor<?xf32, #SparseVector>
+ %0 = linalg.generic #trait_op
+ ins(%arga: tensor<?xcomplex<f32>, #SparseVector>)
+ outs(%xv: tensor<?xf32, #SparseVector>) {
+ ^bb(%a: complex<f32>, %x: f32):
+ %1 = complex.im %a : complex<f32>
+ linalg.yield %1 : f32
+ } -> tensor<?xf32, #SparseVector>
+ return %0 : tensor<?xf32, #SparseVector>
+ }
+
+ func.func @dump(%arg0: tensor<?xf32, #SparseVector>) { // Prints the first 4 values (padded with -1.0) and the first 4 indices (padded with 0).
+ %c0 = arith.constant 0 : index
+ %d0 = arith.constant -1.0 : f32
+ %values = sparse_tensor.values %arg0 : tensor<?xf32, #SparseVector> to memref<?xf32>
+ %0 = vector.transfer_read %values[%c0], %d0: memref<?xf32>, vector<4xf32>
+ vector.print %0 : vector<4xf32>
+ %indices = sparse_tensor.indices %arg0, %c0 : tensor<?xf32, #SparseVector> to memref<?xindex>
+ %1 = vector.transfer_read %indices[%c0], %c0: memref<?xindex>, vector<4xindex>
+ vector.print %1 : vector<4xindex>
+ return
+ }
+
+ // Driver method: builds a sparse complex vector, applies cre and cim to it, and verifies both results.
+ func.func @entry() {
+ // Set up a sparse complex vector with nonzeros at indices 0, 20, and 31.
+ %v1 = arith.constant sparse<
+ [ [0], [20], [31] ],
+ [ (5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > : tensor<32xcomplex<f32>>
+ %sv1 = sparse_tensor.convert %v1 : tensor<32xcomplex<f32>> to tensor<?xcomplex<f32>, #SparseVector>
+
+ // Call sparse vector kernels.
+ %0 = call @cre(%sv1)
+ : (tensor<?xcomplex<f32>, #SparseVector>) -> tensor<?xf32, #SparseVector>
+
+ %1 = call @cim(%sv1)
+ : (tensor<?xcomplex<f32>, #SparseVector>) -> tensor<?xf32, #SparseVector>
+
+ //
+ // Verify the results (real parts, then imaginary parts); the -1 value and the trailing index 0 are padding from @dump.
+ //
+ // CHECK: ( 5.13, 3, 5, -1 )
+ // CHECK-NEXT: ( 0, 20, 31, 0 )
+ // CHECK-NEXT: ( 2, 4, 6, -1 )
+ // CHECK-NEXT: ( 0, 20, 31, 0 )
+ //
+ call @dump(%0) : (tensor<?xf32, #SparseVector>) -> ()
+ call @dump(%1) : (tensor<?xf32, #SparseVector>) -> ()
+
+ // Release the resources.
+ sparse_tensor.release %sv1 : tensor<?xcomplex<f32>, #SparseVector>
+ sparse_tensor.release %0 : tensor<?xf32, #SparseVector>
+ sparse_tensor.release %1 : tensor<?xf32, #SparseVector>
+ return
+ }
+}
More information about the Mlir-commits
mailing list