[Mlir-commits] [mlir] 480cd4c - [mlir] Move the complex support of std.constant to a new complex.constant operation
River Riddle
llvmlistbot at llvm.org
Wed Jan 26 12:04:57 PST 2022
Author: River Riddle
Date: 2022-01-26T11:52:00-08:00
New Revision: 480cd4cb8560532e544fc0c234749912dde759c6
URL: https://github.com/llvm/llvm-project/commit/480cd4cb8560532e544fc0c234749912dde759c6
DIFF: https://github.com/llvm/llvm-project/commit/480cd4cb8560532e544fc0c234749912dde759c6.diff
LOG: [mlir] Move the complex support of std.constant to a new complex.constant operation
This is part of splitting up the standard dialect.
Differential Revision: https://reviews.llvm.org/D118182
Added:
mlir/test/Dialect/Complex/invalid.mlir
Modified:
mlir/include/mlir/Dialect/Complex/IR/Complex.h
mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/lib/Dialect/Complex/IR/CMakeLists.txt
mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Complex/canonicalize.mlir
mlir/test/Dialect/Complex/ops.mlir
mlir/test/Dialect/Standard/invalid.mlir
mlir/test/Dialect/Standard/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h
index 6f3026a5affb4..2a8a8e7a18332 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/Complex.h
+++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h
@@ -10,13 +10,12 @@
#define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
//===----------------------------------------------------------------------===//
// Complex Dialect
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
index 4382183254ac4..6cb7d9d92b58c 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
@@ -19,7 +19,7 @@ def Complex_Dialect : Dialect {
arithmetic ops.
}];
- let dependentDialects = ["arith::ArithmeticDialect", "StandardOpsDialect"];
+ let dependentDialects = ["arith::ArithmeticDialect"];
let hasConstantMaterializer = 1;
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 02b44ff16f561..a79d7ac8a2157 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -10,6 +10,7 @@
#define COMPLEX_OPS
include "mlir/Dialect/Complex/IR/ComplexBase.td"
+include "mlir/IR/OpAsmInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -76,6 +77,40 @@ def AddOp : ComplexArithmeticOp<"add"> {
}];
}
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+def ConstantOp : Complex_Op<"constant", [
+ ConstantLike, NoSideEffect,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
+ let summary = "complex number constant operation";
+ let description = [{
+ The `complex.constant` operation creates a constant complex number from an
+ attribute containing the real and imaginary parts.
+
+ Example:
+
+ ```mlir
+ %a = complex.constant [0.1, -1.0] : complex<f64>
+ ```
+ }];
+
+ let arguments = (ins ArrayAttr:$value);
+ let results = (outs Complex<AnyFloat>:$complex);
+
+ let assemblyFormat = "$value attr-dict `:` type($complex)";
+ let hasFolder = 1;
+ let verifier = [{ return ::verify(*this); }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if a constant operation can be built with the given value
+ /// and result type.
+ static bool isBuildableWith(Attribute value, Type type);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// CreateOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
index fdb7e748a01b1..da419d9b25994 100644
--- a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
@@ -11,6 +11,6 @@ add_mlir_dialect_library(MLIRComplex
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRDialect
+ MLIRInferTypeOpInterface
MLIRIR
- MLIRStandard
)
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
index a5aa9799f5c16..f189c2b2666d0 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
@@ -8,7 +8,6 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
using namespace mlir;
@@ -25,9 +24,10 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder,
Attribute value,
Type type,
Location loc) {
- // TODO complex.constant
- if (type.isa<ComplexType>())
- return builder.create<ConstantOp>(loc, type, value);
+ if (complex::ConstantOp::isBuildableWith(value, type)) {
+ return builder.create<complex::ConstantOp>(loc, type,
+ value.cast<ArrayAttr>());
+ }
if (arith::ConstantOp::isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, type, value);
return nullptr;
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 58412d37605b0..36745c5264b86 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -13,11 +13,54 @@ using namespace mlir;
using namespace mlir::complex;
//===----------------------------------------------------------------------===//
-// TableGen'd op method definitions
+// ConstantOp
//===----------------------------------------------------------------------===//
-#define GET_OP_CLASSES
-#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
+OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.empty() && "constant has no operands");
+ return getValue();
+}
+
+void ConstantOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult(), "cst");
+}
+
+bool ConstantOp::isBuildableWith(Attribute value, Type type) {
+ if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
+ auto complexTy = type.dyn_cast<ComplexType>();
+ if (!complexTy)
+ return false;
+ auto complexEltTy = complexTy.getElementType();
+ return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
+ arrAttr[1].getType() == complexEltTy;
+ }
+ return false;
+}
+
+static LogicalResult verify(ConstantOp op) {
+ ArrayAttr arrayAttr = op.getValue();
+ if (arrayAttr.size() != 2) {
+ return op.emitOpError(
+ "requires 'value' to be a complex constant, represented as array of "
+ "two values");
+ }
+
+ auto complexEltTy = op.getType().getElementType();
+ if (complexEltTy != arrayAttr[0].getType() ||
+ complexEltTy != arrayAttr[1].getType()) {
+ return op.emitOpError()
+ << "requires attribute's element types (" << arrayAttr[0].getType()
+ << ", " << arrayAttr[1].getType()
+ << ") to match the element type of the op's return type ("
+ << complexEltTy << ")";
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CreateOp
+//===----------------------------------------------------------------------===//
OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "binary op takes two operands");
@@ -32,6 +75,10 @@ OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+//===----------------------------------------------------------------------===//
+// ImOp
+//===----------------------------------------------------------------------===//
+
OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1 && "unary op takes 1 operand");
ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
@@ -42,6 +89,10 @@ OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+//===----------------------------------------------------------------------===//
+// ReOp
+//===----------------------------------------------------------------------===//
+
OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1 && "unary op takes 1 operand");
ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
@@ -51,3 +102,10 @@ OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
return createOp.getOperand(0);
return {};
}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 1110879395259..7487aea0902de 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -669,8 +669,8 @@ static void print(OpAsmPrinter &p, ConstantOp &op) {
p << ' ';
p << op.getValue();
- // If the value is a symbol reference or Array, print a trailing type.
- if (op.getValue().isa<SymbolRefAttr, ArrayAttr>())
+ // If the value is a symbol reference, print a trailing type.
+ if (op.getValue().isa<SymbolRefAttr>())
p << " : " << op.getType();
}
@@ -681,10 +681,9 @@ static ParseResult parseConstantOp(OpAsmParser &parser,
parser.parseAttribute(valueAttr, "value", result.attributes))
return failure();
- // If the attribute is a symbol reference or array, then we expect a trailing
- // type.
+ // If the attribute is a symbol reference, then we expect a trailing type.
Type type;
- if (!valueAttr.isa<SymbolRefAttr, ArrayAttr>())
+ if (!valueAttr.isa<SymbolRefAttr>())
type = valueAttr.getType();
else if (parser.parseColonType(type))
return failure();
@@ -705,24 +704,6 @@ static LogicalResult verify(ConstantOp &op) {
return op.emitOpError() << "requires attribute's type (" << value.getType()
<< ") to match op's return type (" << type << ")";
- if (auto complexTy = type.dyn_cast<ComplexType>()) {
- auto arrayAttr = value.dyn_cast<ArrayAttr>();
- if (!complexTy || arrayAttr.size() != 2)
- return op.emitOpError(
- "requires 'value' to be a complex constant, represented as array of "
- "two values");
- auto complexEltTy = complexTy.getElementType();
- if (complexEltTy != arrayAttr[0].getType() ||
- complexEltTy != arrayAttr[1].getType()) {
- return op.emitOpError()
- << "requires attribute's element types (" << arrayAttr[0].getType()
- << ", " << arrayAttr[1].getType()
- << ") to match the element type of the op's return type ("
- << complexEltTy << ")";
- }
- return success();
- }
-
if (type.isa<FunctionType>()) {
auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
if (!fnAttr)
@@ -769,19 +750,8 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
// SymbolRefAttr can only be used with a function type.
if (value.isa<SymbolRefAttr>())
return type.isa<FunctionType>();
- // The attribute must have the same type as 'type'.
- if (!value.getType().isa<NoneType>() && value.getType() != type)
- return false;
- // Finally, check that the attribute kind is handled.
- if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
- auto complexTy = type.dyn_cast<ComplexType>();
- if (!complexTy)
- return false;
- auto complexEltTy = complexTy.getElementType();
- return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
- arrAttr[1].getType() == complexEltTy;
- }
- return value.isa<UnitAttr>();
+ // Otherwise, this must be a UnitAttr.
+ return value.isa<UnitAttr>() && type.isa<NoneType>();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index 038de9908cf2a..c68d87e8c0773 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -27,7 +27,7 @@ func @create_of_real_and_imag_different_operand(
func @real_of_const() -> f32 {
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: return %[[CST]] : f32
- %complex = constant [1.0 : f32, 0.0 : f32] : complex<f32>
+ %complex = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
%1 = complex.re %complex : complex<f32>
return %1 : f32
}
@@ -47,7 +47,7 @@ func @real_of_create_op() -> f32 {
func @imag_of_const() -> f32 {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: return %[[CST]] : f32
- %complex = constant [1.0 : f32, 0.0 : f32] : complex<f32>
+ %complex = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
%1 = complex.im %complex : complex<f32>
return %1 : f32
}
diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
new file mode 100644
index 0000000000000..ec046effacf8c
--- /dev/null
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -split-input-file %s -verify-diagnostics
+
+func @complex_constant_wrong_array_attribute_length() {
+ // expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}}
+ %0 = complex.constant [1.0 : f32] : complex<f32>
+ return
+}
+
+// -----
+
+func @complex_constant_wrong_element_types() {
+ // expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}}
+ %0 = complex.constant [1.0 : f32, -1.0 : f32] : complex<f64>
+ return
+}
+
+// -----
+
+func @complex_constant_two_different_element_types() {
+ // expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}}
+ %0 = complex.constant [1.0 : f32, -1.0 : f64] : complex<f64>
+ return
+}
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index 3fc0e9299c0fd..75bb082efb2ab 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -5,6 +5,12 @@
// CHECK-LABEL: func @ops(
// CHECK-SAME: %[[F:.*]]: f32) {
func @ops(%f: f32) {
+ // CHECK: complex.constant [1.{{.*}}, -1.{{.*}}] : complex<f64>
+ %cst_f64 = complex.constant [0.1, -1.0] : complex<f64>
+
+ // CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex<f32>
+ %cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex<f32>
+
// CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32>
%complex = complex.create %f, %f : complex<f32>
@@ -51,4 +57,3 @@ func @ops(%f: f32) {
%diff = complex.sub %complex, %complex : complex<f32>
return
}
-
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index defd1aed97b1a..836158dd2160b 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -8,30 +8,6 @@ func @unsupported_attribute() {
// -----
-func @complex_constant_wrong_array_attribute_length() {
- // expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}}
- %0 = constant [1.0 : f32] : complex<f32>
- return
-}
-
-// -----
-
-func @complex_constant_wrong_element_types() {
- // expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}}
- %0 = constant [1.0 : f32, -1.0 : f32] : complex<f64>
- return
-}
-
-// -----
-
-func @complex_constant_two_different_element_types() {
- // expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}}
- %0 = constant [1.0 : f32, -1.0 : f64] : complex<f64>
- return
-}
-
-// -----
-
func @return_i32_f32() -> (i32, f32) {
%0 = arith.constant 1 : i32
%1 = arith.constant 1. : f32
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index 64322f066b354..1c3570eedb7ea 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -51,18 +51,6 @@ func @switch_i64(%flag : i64, %caseOperand : i32) {
return
}
-// CHECK-LABEL: func @constant_complex_f32(
-func @constant_complex_f32() -> complex<f32> {
- %result = constant [0.1 : f32, -1.0 : f32] : complex<f32>
- return %result : complex<f32>
-}
-
-// CHECK-LABEL: func @constant_complex_f64(
-func @constant_complex_f64() -> complex<f64> {
- %result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
- return %result : complex<f64>
-}
-
// CHECK-LABEL: func @vector_splat_0d(
func @vector_splat_0d(%a: f32) -> vector<f32> {
// CHECK: splat %{{.*}} : vector<f32>
More information about the Mlir-commits mailing list