[Mlir-commits] [mlir] 2c7b068 - Fix tensor.extract for complex elements
Frederik Gossen
llvmlistbot at llvm.org
Thu Jan 27 19:33:34 PST 2022
Author: Frederik Gossen
Date: 2022-01-28T04:33:15+01:00
New Revision: 2c7b0685e179bcd96045098df8aaef3b73b434ba
URL: https://github.com/llvm/llvm-project/commit/2c7b0685e179bcd96045098df8aaef3b73b434ba
DIFF: https://github.com/llvm/llvm-project/commit/2c7b0685e179bcd96045098df8aaef3b73b434ba.diff
LOG: Fix tensor.extract for complex elements
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index cfec22be37e5f..c8a84be219b2e 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_TENSOR_IR_TENSOR_H_
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
index 493de0aa053f0..9bf174d4e15fd 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
@@ -46,7 +46,10 @@ def Tensor_Dialect : Dialect {
}];
let hasConstantMaterializer = 1;
- let dependentDialects = ["arith::ArithmeticDialect"];
+ let dependentDialects = [
+ "arith::ArithmeticDialect",
+ "complex::ComplexDialect",
+ ];
}
#endif // TENSOR_BASE
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bba7f9770f21a..2cc927555c3d6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
@@ -29,6 +30,9 @@ Operation *TensorDialect::materializeConstant(OpBuilder &builder,
Location loc) {
if (arith::ConstantOp::isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, value, type);
+ if (complex::ConstantOp::isBuildableWith(value, type))
+ return builder.create<complex::ConstantOp>(loc, type,
+ value.cast<ArrayAttr>());
if (ConstantOp::isBuildableWith(value, type))
return builder.create<ConstantOp>(loc, value, type);
return nullptr;
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3084a262af7dd..81be542915d69 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -65,7 +65,7 @@ func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
// -----
// CHECK-LABEL: func @fold_extract
-func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
+func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
%const_0 = arith.constant 0 : index
%const_1 = arith.constant 1 : index
%const_3 = arith.constant 3 : index
@@ -87,11 +87,16 @@ func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) {
%ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<2x2x2xf16>
// Fold an extract into a dense tensor.
- %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
+ %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
%ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
- // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]]
- return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
+ // Fold an extract into a complex constant.
+ // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
+ %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
+ %ext_5 = tensor.extract %4[] : tensor<complex<f32>>
+
+ // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
+ return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
}
// -----
More information about the Mlir-commits mailing list