[Mlir-commits] [mlir] 28b6d41 - [mlir][sparse] add support for complex zero/one building
Aart Bik
llvmlistbot at llvm.org
Fri May 20 08:53:48 PDT 2022
Author: Aart Bik
Date: 2022-05-20T08:53:30-07:00
New Revision: 28b6d412afc52d46e44960de429e6688826d9f4f
URL: https://github.com/llvm/llvm-project/commit/28b6d412afc52d46e44960de429e6688826d9f4f
DIFF: https://github.com/llvm/llvm-project/commit/28b6d412afc52d46e44960de429e6688826d9f4f.diff
LOG: [mlir][sparse] add support for complex zero/one building
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D126039
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/conversion.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 51d7bd4fa8728..7e7ed55416360 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRBufferization
+ MLIRComplex
MLIRFunc
MLIRIR
MLIRLLVMIR
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 5eb0785baf938..cf574fdabb7ae 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -189,5 +189,7 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
if (tp.isIntOrIndex())
return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
zero);
+ if (tp.dyn_cast<ComplexType>())
+ return builder.create<complex::NotEqualOp>(loc, v, zero);
llvm_unreachable("Non-numeric type");
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 944605bab4b9f..ba897f7f7c14b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/IR/Builders.h"
@@ -102,16 +103,27 @@ Value genIsNonzero(OpBuilder &builder, Location loc, Value v);
//===----------------------------------------------------------------------===//
/// Generates a 0-valued constant of the given type. In addition to
-/// the scalar types (`FloatType`, `IndexType`, `IntegerType`), this also
-/// works for `RankedTensorType` and `VectorType` (for which it generates
-/// a constant `DenseElementsAttr` of zeros).
+/// the scalar types (`ComplexType`, `FloatType`, `IndexType`, `IntegerType`),
+/// this also works for `RankedTensorType` and `VectorType` (for which it
+/// generates a constant `DenseElementsAttr` of zeros). A `ComplexType` is
+/// materialized as a `complex.constant` whose real and imaginary parts are
+/// both the zero attribute of the element type; every other type goes
+/// through `arith.constant` with `builder.getZeroAttr(tp)`.
inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
+ if (auto ctp = tp.dyn_cast<ComplexType>()) {
+ // complex.constant takes a two-element [real, imaginary] array attribute.
+ auto zeroe = builder.getZeroAttr(ctp.getElementType());
+ auto zeroa = builder.getArrayAttr({zeroe, zeroe});
+ return builder.create<complex::ConstantOp>(loc, tp, zeroa);
+ }
return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp));
}
/// Generates a 1-valued constant of the given type. This supports all
/// the same types as `constantZero`.
inline Value constantOne(OpBuilder &builder, Location loc, Type tp) {
+ if (auto ctp = tp.dyn_cast<ComplexType>()) {
+ auto zeroe = builder.getZeroAttr(ctp.getElementType());
+ auto onee = getOneAttr(builder, ctp.getElementType());
+ auto zeroa = builder.getArrayAttr({onee, zeroe});
+ return builder.create<complex::ConstantOp>(loc, tp, zeroa);
+ }
return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp));
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 2b957e6fbad60..8c1f296a39e98 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -112,7 +113,8 @@ struct SparseTensorConversionPass
// The following operations and dialects may be introduced by the
// rewriting rules, and are therefore marked as legal.
target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
- arith::IndexCastOp, linalg::FillOp, linalg::YieldOp,
+ arith::IndexCastOp, complex::ConstantOp,
+ complex::NotEqualOp, linalg::FillOp, linalg::YieldOp,
tensor::ExtractOp>();
target
.addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index cdd1536614742..a1838141cc0ef 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -173,9 +173,10 @@ func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32
}
// CHECK-LABEL: func @sparse_convert_1d(
-// CHECK-SAME: %[[A:.*]]: tensor<?xi32>) -> !llvm.ptr<i8>
+// CHECK-SAME: %[[A:.*]]: tensor<?xi32>) -> !llvm.ptr<i8> {
// CHECK-DAG: %[[EmptyCOO:.*]] = arith.constant 4 : i32
// CHECK-DAG: %[[FromCOO:.*]] = arith.constant 2 : i32
+// CHECK-DAG: %[[I0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xi32>
@@ -191,8 +192,11 @@ func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32
// CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U]] step %[[C1]] {
// CHECK: %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xi32>
-// CHECK: memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex>
-// CHECK: call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Z]])
+// CHECK: %[[N:.*]] = arith.cmpi ne, %[[E]], %[[I0]] : i32
+// CHECK: scf.if %[[N]] {
+// CHECK: memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex>
+// CHECK: call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Z]])
+// CHECK: }
// CHECK: }
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
// CHECK: call @delSparseTensorCOOI32(%[[C]])
@@ -202,6 +206,28 @@ func.func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVecto
return %0 : tensor<?xi32, #SparseVector>
}
+// CHECK-LABEL: func @sparse_convert_complex(
+// CHECK-SAME: %[[A:.*]]: tensor<100xcomplex<f64>>) -> !llvm.ptr<i8> {
+// CHECK-DAG: %[[CC:.*]] = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C100:.*]] = arith.constant 100 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C100]] step %[[C1]] {
+// CHECK: %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<100xcomplex<f64>>
+// CHECK: %[[N:.*]] = complex.neq %[[E]], %[[CC]] : complex<f64>
+// CHECK: scf.if %[[N]] {
+// CHECK: memref.store %[[I]], %{{.*}}[%[[C0]]] : memref<1xindex>
+// CHECK: call @addEltC64
+// CHECK: }
+// CHECK: }
+// CHECK: %[[T:.*]] = call @newSparseTensor
+// CHECK: call @delSparseTensorCOOC64
+// CHECK: return %[[T]] : !llvm.ptr<i8>
+// Dense-to-sparse conversion of a complex<f64> vector: the expectations
+// above require each element to be tested with complex.neq against a
+// complex zero constant before it is added to the COO buffer.
+func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>)
+ -> tensor<100xcomplex<f64>, #SparseVector> {
+ %sparse = sparse_tensor.convert %arg0
+ : tensor<100xcomplex<f64>> to tensor<100xcomplex<f64>, #SparseVector>
+ return %sparse : tensor<100xcomplex<f64>, #SparseVector>
+}
+
// CHECK-LABEL: func @sparse_convert_1d_ss(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK-DAG: %[[ToCOO:.*]] = arith.constant 5 : i32
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 53d8467739ef7..3bb09cb168c13 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2033,6 +2033,7 @@ cc_library(
":Affine",
":ArithmeticDialect",
":BufferizationDialect",
+ ":ComplexDialect",
":FuncDialect",
":FuncTransforms",
":IR",
More information about the Mlir-commits
mailing list