[Mlir-commits] [mlir] 28b6d41 - [mlir][sparse] add support for complex zero/one building

Aart Bik llvmlistbot at llvm.org
Fri May 20 08:53:48 PDT 2022


Author: Aart Bik
Date: 2022-05-20T08:53:30-07:00
New Revision: 28b6d412afc52d46e44960de429e6688826d9f4f

URL: https://github.com/llvm/llvm-project/commit/28b6d412afc52d46e44960de429e6688826d9f4f
DIFF: https://github.com/llvm/llvm-project/commit/28b6d412afc52d46e44960de429e6688826d9f4f.diff

LOG: [mlir][sparse] add support for complex zero/one building

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D126039

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 51d7bd4fa8728..7e7ed55416360 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   LINK_LIBS PUBLIC
   MLIRArithmetic
   MLIRBufferization
+  MLIRComplex
   MLIRFunc
   MLIRIR
   MLIRLLVMIR

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 5eb0785baf938..cf574fdabb7ae 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -189,5 +189,7 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
   if (tp.isIntOrIndex())
     return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
+  if (tp.dyn_cast<ComplexType>())
+    return builder.create<complex::NotEqualOp>(loc, v, zero);
   llvm_unreachable("Non-numeric type");
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 944605bab4b9f..ba897f7f7c14b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
 #include "mlir/IR/Builders.h"
@@ -102,16 +103,27 @@ Value genIsNonzero(OpBuilder &builder, Location loc, Value v);
 //===----------------------------------------------------------------------===//
 
 /// Generates a 0-valued constant of the given type.  In addition to
-/// the scalar types (`FloatType`, `IndexType`, `IntegerType`), this also
-/// works for `RankedTensorType` and `VectorType` (for which it generates
-/// a constant `DenseElementsAttr` of zeros).
+/// the scalar types (`ComplexType`, `FloatType`, `IndexType`, `IntegerType`),
+/// this also works for `RankedTensorType` and `VectorType` (for which it
+/// generates a constant `DenseElementsAttr` of zeros).
 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
+  if (auto ctp = tp.dyn_cast<ComplexType>()) {
+    auto zeroe = builder.getZeroAttr(ctp.getElementType());
+    auto zeroa = builder.getArrayAttr({zeroe, zeroe});
+    return builder.create<complex::ConstantOp>(loc, tp, zeroa);
+  }
   return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp));
 }
 
 /// Generates a 1-valued constant of the given type.  This supports all
 /// the same types as `constantZero`.
 inline Value constantOne(OpBuilder &builder, Location loc, Type tp) {
+  if (auto ctp = tp.dyn_cast<ComplexType>()) {
+    auto zeroe = builder.getZeroAttr(ctp.getElementType());
+    auto onee = getOneAttr(builder, ctp.getElementType());
+    auto zeroa = builder.getArrayAttr({onee, zeroe});
+    return builder.create<complex::ConstantOp>(loc, tp, zeroa);
+  }
   return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp));
 }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 2b957e6fbad60..8c1f296a39e98 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -112,7 +113,8 @@ struct SparseTensorConversionPass
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
     target.addLegalOp<arith::CmpFOp, arith::CmpIOp, arith::ConstantOp,
-                      arith::IndexCastOp, linalg::FillOp, linalg::YieldOp,
+                      arith::IndexCastOp, complex::ConstantOp,
+                      complex::NotEqualOp, linalg::FillOp, linalg::YieldOp,
                       tensor::ExtractOp>();
     target
         .addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index cdd1536614742..a1838141cc0ef 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -173,9 +173,10 @@ func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32
 }
 
 // CHECK-LABEL: func @sparse_convert_1d(
-//  CHECK-SAME: %[[A:.*]]: tensor<?xi32>) -> !llvm.ptr<i8>
+//  CHECK-SAME: %[[A:.*]]: tensor<?xi32>) -> !llvm.ptr<i8> {
 //   CHECK-DAG: %[[EmptyCOO:.*]] = arith.constant 4 : i32
 //   CHECK-DAG: %[[FromCOO:.*]] = arith.constant 2 : i32
+//   CHECK-DAG: %[[I0:.*]] = arith.constant 0 : i32
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xi32>
@@ -191,8 +192,11 @@ func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32
 //       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U]] step %[[C1]] {
 //       CHECK:   %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xi32>
-//       CHECK:   memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex>
-//       CHECK:   call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Z]])
+//       CHECK:   %[[N:.*]] = arith.cmpi ne, %[[E]], %[[I0]] : i32
+//       CHECK:   scf.if %[[N]] {
+//       CHECK:     memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex>
+//       CHECK:     call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Z]])
+//       CHECK:   }
 //       CHECK: }
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
 //       CHECK: call @delSparseTensorCOOI32(%[[C]])
@@ -202,6 +206,28 @@ func.func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVecto
   return %0 : tensor<?xi32, #SparseVector>
 }
 
+// CHECK-LABEL: func @sparse_convert_complex(
+//  CHECK-SAME: %[[A:.*]]: tensor<100xcomplex<f64>>) -> !llvm.ptr<i8> {
+//   CHECK-DAG: %[[CC:.*]] = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64>
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[C100:.*]] = arith.constant 100 : index
+//       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C100]] step %[[C1]] {
+//       CHECK:   %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<100xcomplex<f64>>
+//       CHECK:   %[[N:.*]] = complex.neq %[[E]], %[[CC]] : complex<f64>
+//       CHECK:   scf.if %[[N]] {
+//       CHECK:     memref.store %[[I]], %{{.*}}[%[[C0]]] : memref<1xindex>
+//       CHECK:     call @addEltC64
+//       CHECK:   }
+//       CHECK: }
+//       CHECK: %[[T:.*]] = call @newSparseTensor
+//       CHECK: call @delSparseTensorCOOC64
+//       CHECK: return %[[T]] : !llvm.ptr<i8>
+func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100xcomplex<f64>, #SparseVector> {
+  %0 = sparse_tensor.convert %arg0 : tensor<100xcomplex<f64>> to tensor<100xcomplex<f64>, #SparseVector>
+  return %0 : tensor<100xcomplex<f64>, #SparseVector>
+}
+
 // CHECK-LABEL: func @sparse_convert_1d_ss(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //  CHECK-DAG:  %[[ToCOO:.*]] = arith.constant 5 : i32

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 53d8467739ef7..3bb09cb168c13 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2033,6 +2033,7 @@ cc_library(
         ":Affine",
         ":ArithmeticDialect",
         ":BufferizationDialect",
+        ":ComplexDialect",
         ":FuncDialect",
         ":FuncTransforms",
         ":IR",


        


More information about the Mlir-commits mailing list