[llvm] [mlir] [mlir][sparse] Support explicit/implicit value for complex type (PR #90771)

Yinying Li via llvm-commits llvm-commits at lists.llvm.org
Wed May 1 14:12:05 PDT 2024


https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/90771

>From 72793de30b0c992daf81f397059afb884d825a2d Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Wed, 1 May 2024 19:51:14 +0000
Subject: [PATCH 1/2] [mlir][sparse] Support explicit/implicit value for
 complex type

---
 .../Dialect/SparseTensor/IR/CMakeLists.txt    |  1 +
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  7 +++++++
 .../Transforms/Utils/CodegenUtils.h           |  9 ++++++---
 .../SparseTensor/roundtrip_encoding.mlir      | 15 ++++++++++++++
 .../SparseTensor/sparse_matmul_one.mlir       | 20 +++++++++++++------
 .../llvm-project-overlay/mlir/BUILD.bazel     |  1 +
 6 files changed, 44 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
index dd6f1037f71b53..6f59b69bddce86 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -45,6 +45,7 @@ add_mlir_dialect_library(MLIRSparseTensorDialect
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRComplexDialect
   MLIRDialect
   MLIRDialectUtils
   MLIRIR
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 028a69da10c1e1..dac028e8d53cb0 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -17,6 +17,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -663,6 +664,9 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
         explicitVal = result;
       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
         explicitVal = result;
+      } else if (auto result =
+                     llvm::dyn_cast<::mlir::complex::NumberAttr>(attr)) {
+        explicitVal = result;
       } else {
         parser.emitError(parser.getNameLoc(),
                          "expected a numeric value for explicitVal");
@@ -678,6 +682,9 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
         implicitVal = result;
       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
         implicitVal = result;
+      } else if (auto result =
+                     llvm::dyn_cast<::mlir::complex::NumberAttr>(attr)) {
+        implicitVal = result;
       } else {
         parser.emitError(parser.getNameLoc(),
                          "expected a numeric value for implicitVal");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index cf3c35f5fa4c78..d0ef8a6860bb2d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -401,9 +401,12 @@ inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
 
 // Generates a constant from a validated value carrying attribute.
 inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) {
-  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
-    Type tp = cast<TypedAttr>(arrayAttr[0]).getType();
-    return builder.create<complex::ConstantOp>(loc, tp, arrayAttr);
+  if (auto complexAttr = dyn_cast<complex::NumberAttr>(attr)) {
+    Type tp = cast<ComplexType>(complexAttr.getType()).getElementType();
+    return builder.create<complex::ConstantOp>(
+        loc, complexAttr.getType(),
+        builder.getArrayAttr({FloatAttr::get(tp, complexAttr.getReal()),
+                              FloatAttr::get(tp, complexAttr.getImag())}));
   }
   return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr));
 }
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 7eeda9a9880268..7fb1c76c1a1ff6 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -80,6 +80,21 @@ func.func private @sparse_csr(tensor<?x?xi64, #CSR_OnlyOnes>)
 
 // -----
 
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 64,
+  crdWidth = 64,
+  explicitVal = #complex.number<:f32 1.0, 0.0>,
+  implicitVal = #complex.number<:f32 0.0, 0.0>
+}>
+
+// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = #complex.number<:f32 1.000000e+00, 0.000000e+00> : complex<f32>, implicitVal = #complex.number<:f32 0.000000e+00, 0.000000e+00> : complex<f32> }>
+// CHECK-LABEL: func private @sparse_csr(
+// CHECK-SAME: tensor<?x?xcomplex<f32>, #[[$CSR_OnlyOnes]]>)
+func.func private @sparse_csr(tensor<?x?xcomplex<f32>, #CSR_OnlyOnes>)
+
+// -----
+
 #BCSR = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
 }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
index 82f3147d3206bd..be2172515d08bf 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
@@ -2,9 +2,9 @@
 // RUN:             --sparsification-and-bufferization | FileCheck %s
 
 #CSR_ones_complex = #sparse_tensor.encoding<{
-  map = (d0, d1) -> (d0 : dense, d1 : compressed)
-// explicitVal = (1.0, 0.0) : complex<f32>,
-// implicitVal = (0.0, 0.0) : complex<f32>
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  explicitVal = #complex.number<:f32 1.0, 0.0>,
+  implicitVal = #complex.number<:f32 0.0, 0.0>
 }>
 
 #CSR_ones_fp = #sparse_tensor.encoding<{
@@ -20,9 +20,17 @@
 }>
 
 // CHECK-LABEL:   func.func @matmul_complex
-//
-// TODO: make this work
-//
+// CHECK:         scf.for
+// CHECK:           scf.for
+// CHECK:             %[[X:.*]] = memref.load
+// CHECK:             scf.for
+// CHECK:               %[[I:.*]] = memref.load
+// CHECK:               %[[Y:.*]] = memref.load
+// CHECK:               %[[M:.*]] = complex.add %[[Y]], %[[X]] : complex<f32>
+// CHECK:               memref.store %[[M]]
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
 func.func @matmul_complex(%a: tensor<10x20xcomplex<f32>>,
                           %b: tensor<20x30xcomplex<f32>, #CSR_ones_complex>,
                           %c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>> {
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index acd2d3a14d7411..13c246a3fec6af 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3066,6 +3066,7 @@ cc_library(
         ":ArithDialect",
         ":BufferizationInterfaces",
         ":BytecodeOpInterface",
+        ":ComplexDialect",
         ":DialectUtils",
         ":IR",
         ":InferTypeOpInterface",

>From 62d23eeea9a3dacff5bd486d340cefaab1cdb93c Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Wed, 1 May 2024 21:11:25 +0000
Subject: [PATCH 2/2] Address review comments: drop redundant ::mlir:: qualifier

---
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index dac028e8d53cb0..de3d3006ebaac5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -664,8 +664,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
         explicitVal = result;
       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
         explicitVal = result;
-      } else if (auto result =
-                     llvm::dyn_cast<::mlir::complex::NumberAttr>(attr)) {
+      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
         explicitVal = result;
       } else {
         parser.emitError(parser.getNameLoc(),
@@ -682,8 +681,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
         implicitVal = result;
       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
         implicitVal = result;
-      } else if (auto result =
-                     llvm::dyn_cast<::mlir::complex::NumberAttr>(attr)) {
+      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
         implicitVal = result;
       } else {
         parser.emitError(parser.getNameLoc(),



More information about the llvm-commits mailing list