[Mlir-commits] [mlir] 2c7b068 - Fix tensor.extract for complex elements

Frederik Gossen llvmlistbot at llvm.org
Thu Jan 27 19:33:34 PST 2022


Author: Frederik Gossen
Date: 2022-01-28T04:33:15+01:00
New Revision: 2c7b0685e179bcd96045098df8aaef3b73b434ba

URL: https://github.com/llvm/llvm-project/commit/2c7b0685e179bcd96045098df8aaef3b73b434ba
DIFF: https://github.com/llvm/llvm-project/commit/2c7b0685e179bcd96045098df8aaef3b73b434ba.diff

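LOG: Fix tensor.extract for complex elements

Teach TensorDialect::materializeConstant to build a complex.constant so that
a tensor.extract from a constant tensor with complex elements can be folded.
The Tensor dialect now declares the Complex dialect as a dependent dialect.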

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
    mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index cfec22be37e5f..c8a84be219b2e 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_TENSOR_IR_TENSOR_H_
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
index 493de0aa053f0..9bf174d4e15fd 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
@@ -46,7 +46,10 @@ def Tensor_Dialect : Dialect {
   }];
 
   let hasConstantMaterializer = 1;
-  let dependentDialects = ["arith::ArithmeticDialect"];
+  let dependentDialects = [
+    "arith::ArithmeticDialect",
+    "complex::ComplexDialect",
+  ];
 }
 
 #endif // TENSOR_BASE

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bba7f9770f21a..2cc927555c3d6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
@@ -29,6 +30,9 @@ Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   if (arith::ConstantOp::isBuildableWith(value, type))
     return builder.create<arith::ConstantOp>(loc, value, type);
+  if (complex::ConstantOp::isBuildableWith(value, type))
+    return builder.create<complex::ConstantOp>(loc, type,
+                                               value.cast<ArrayAttr>());
   if (ConstantOp::isBuildableWith(value, type))
     return builder.create<ConstantOp>(loc, value, type);
   return nullptr;

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3084a262af7dd..81be542915d69 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -65,7 +65,7 @@ func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
 // -----
 
 // CHECK-LABEL: func @fold_extract
-func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
+func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index
   %const_1 = arith.constant 1 : index
   %const_3 = arith.constant 3 : index
@@ -87,11 +87,16 @@ func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
   %ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<2x2x2xf16>
 
   // Fold an extract into a dense tensor.
-   %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
+  %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
   %ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
 
-  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]]
-  return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
+  // Fold an extract into a complex constant.
+  // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
+  %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
+  %ext_5 = tensor.extract %4[] : tensor<complex<f32>>
+
+  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
+  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
 }
 
 // -----


        


More information about the Mlir-commits mailing list