[Mlir-commits] [mlir] 480cd4c - [mlir] Move the complex support of std.constant to a new complex.constant operation

River Riddle llvmlistbot at llvm.org
Wed Jan 26 12:04:57 PST 2022


Author: River Riddle
Date: 2022-01-26T11:52:00-08:00
New Revision: 480cd4cb8560532e544fc0c234749912dde759c6

URL: https://github.com/llvm/llvm-project/commit/480cd4cb8560532e544fc0c234749912dde759c6
DIFF: https://github.com/llvm/llvm-project/commit/480cd4cb8560532e544fc0c234749912dde759c6.diff

LOG: [mlir] Move the complex support of std.constant to a new complex.constant operation

This is part of splitting up the standard dialect.

Differential Revision: https://reviews.llvm.org/D118182

Added: 
    mlir/test/Dialect/Complex/invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/Complex.h
    mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Dialect/Complex/IR/CMakeLists.txt
    mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
    mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Complex/canonicalize.mlir
    mlir/test/Dialect/Complex/ops.mlir
    mlir/test/Dialect/Standard/invalid.mlir
    mlir/test/Dialect/Standard/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h
index 6f3026a5affb4..2a8a8e7a18332 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/Complex.h
+++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h
@@ -10,13 +10,12 @@
 #define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
 
 //===----------------------------------------------------------------------===//
 // Complex Dialect

diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
index 4382183254ac4..6cb7d9d92b58c 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
@@ -19,7 +19,7 @@ def Complex_Dialect : Dialect {
     arithmetic ops.
   }];
 
-  let dependentDialects = ["arith::ArithmeticDialect", "StandardOpsDialect"];
+  let dependentDialects = ["arith::ArithmeticDialect"];
   let hasConstantMaterializer = 1;
   let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
 }

diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 02b44ff16f561..a79d7ac8a2157 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -10,6 +10,7 @@
 #define COMPLEX_OPS
 
 include "mlir/Dialect/Complex/IR/ComplexBase.td"
+include "mlir/IR/OpAsmInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
@@ -76,6 +77,40 @@ def AddOp : ComplexArithmeticOp<"add"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+def ConstantOp : Complex_Op<"constant", [
+    ConstantLike, NoSideEffect,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "complex number constant operation";
+  let description = [{
+    The `complex.constant` operation creates a constant complex number from an
+    attribute containing the real and imaginary parts.
+
+    Example:
+
+    ```mlir
+    %a = complex.constant [0.1, -1.0] : complex<f64>
+    ```
+  }];
+
+  let arguments = (ins ArrayAttr:$value);
+  let results = (outs Complex<AnyFloat>:$complex);
+
+  let assemblyFormat = "$value attr-dict `:` type($complex)";
+  let hasFolder = 1;
+  let verifier = [{ return ::verify(*this); }];
+  
+  let extraClassDeclaration = [{
+    /// Returns true if a constant operation can be built with the given value
+    /// and result type.
+    static bool isBuildableWith(Attribute value, Type type);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // CreateOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
index fdb7e748a01b1..da419d9b25994 100644
--- a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt
@@ -11,6 +11,6 @@ add_mlir_dialect_library(MLIRComplex
   LINK_LIBS PUBLIC
   MLIRArithmetic
   MLIRDialect
+  MLIRInferTypeOpInterface
   MLIRIR
-  MLIRStandard
   )

diff --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
index a5aa9799f5c16..f189c2b2666d0 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
@@ -8,7 +8,6 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 
 using namespace mlir;
 
@@ -25,9 +24,10 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder,
                                                         Attribute value,
                                                         Type type,
                                                         Location loc) {
-  // TODO complex.constant
-  if (type.isa<ComplexType>())
-    return builder.create<ConstantOp>(loc, type, value);
+  if (complex::ConstantOp::isBuildableWith(value, type)) {
+    return builder.create<complex::ConstantOp>(loc, type,
+                                               value.cast<ArrayAttr>());
+  }
   if (arith::ConstantOp::isBuildableWith(value, type))
     return builder.create<arith::ConstantOp>(loc, type, value);
   return nullptr;

diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 58412d37605b0..36745c5264b86 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -13,11 +13,54 @@ using namespace mlir;
 using namespace mlir::complex;
 
 //===----------------------------------------------------------------------===//
-// TableGen'd op method definitions
+// ConstantOp
 //===----------------------------------------------------------------------===//
 
-#define GET_OP_CLASSES
-#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
+OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+  return getValue();
+}
+
+void ConstantOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "cst");
+}
+
+bool ConstantOp::isBuildableWith(Attribute value, Type type) {
+  if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
+    auto complexTy = type.dyn_cast<ComplexType>();
+    if (!complexTy)
+      return false;
+    auto complexEltTy = complexTy.getElementType();
+    return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
+           arrAttr[1].getType() == complexEltTy;
+  }
+  return false;
+}
+
+static LogicalResult verify(ConstantOp op) {
+  ArrayAttr arrayAttr = op.getValue();
+  if (arrayAttr.size() != 2) {
+    return op.emitOpError(
+        "requires 'value' to be a complex constant, represented as array of "
+        "two values");
+  }
+
+  auto complexEltTy = op.getType().getElementType();
+  if (complexEltTy != arrayAttr[0].getType() ||
+      complexEltTy != arrayAttr[1].getType()) {
+    return op.emitOpError()
+           << "requires attribute's element types (" << arrayAttr[0].getType()
+           << ", " << arrayAttr[1].getType()
+           << ") to match the element type of the op's return type ("
+           << complexEltTy << ")";
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CreateOp
+//===----------------------------------------------------------------------===//
 
 OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "binary op takes two operands");
@@ -32,6 +75,10 @@ OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// ImOp
+//===----------------------------------------------------------------------===//
+
 OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1 && "unary op takes 1 operand");
   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
@@ -42,6 +89,10 @@ OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// ReOp
+//===----------------------------------------------------------------------===//
+
 OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1 && "unary op takes 1 operand");
   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
@@ -51,3 +102,10 @@ OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
     return createOp.getOperand(0);
   return {};
 }
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 1110879395259..7487aea0902de 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -669,8 +669,8 @@ static void print(OpAsmPrinter &p, ConstantOp &op) {
     p << ' ';
   p << op.getValue();
 
-  // If the value is a symbol reference or Array, print a trailing type.
-  if (op.getValue().isa<SymbolRefAttr, ArrayAttr>())
+  // If the value is a symbol reference, print a trailing type.
+  if (op.getValue().isa<SymbolRefAttr>())
     p << " : " << op.getType();
 }
 
@@ -681,10 +681,9 @@ static ParseResult parseConstantOp(OpAsmParser &parser,
       parser.parseAttribute(valueAttr, "value", result.attributes))
     return failure();
 
-  // If the attribute is a symbol reference or array, then we expect a trailing
-  // type.
+  // If the attribute is a symbol reference, then we expect a trailing type.
   Type type;
-  if (!valueAttr.isa<SymbolRefAttr, ArrayAttr>())
+  if (!valueAttr.isa<SymbolRefAttr>())
     type = valueAttr.getType();
   else if (parser.parseColonType(type))
     return failure();
@@ -705,24 +704,6 @@ static LogicalResult verify(ConstantOp &op) {
     return op.emitOpError() << "requires attribute's type (" << value.getType()
                             << ") to match op's return type (" << type << ")";
 
-  if (auto complexTy = type.dyn_cast<ComplexType>()) {
-    auto arrayAttr = value.dyn_cast<ArrayAttr>();
-    if (!complexTy || arrayAttr.size() != 2)
-      return op.emitOpError(
-          "requires 'value' to be a complex constant, represented as array of "
-          "two values");
-    auto complexEltTy = complexTy.getElementType();
-    if (complexEltTy != arrayAttr[0].getType() ||
-        complexEltTy != arrayAttr[1].getType()) {
-      return op.emitOpError()
-             << "requires attribute's element types (" << arrayAttr[0].getType()
-             << ", " << arrayAttr[1].getType()
-             << ") to match the element type of the op's return type ("
-             << complexEltTy << ")";
-    }
-    return success();
-  }
-
   if (type.isa<FunctionType>()) {
     auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
     if (!fnAttr)
@@ -769,19 +750,8 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
   // SymbolRefAttr can only be used with a function type.
   if (value.isa<SymbolRefAttr>())
     return type.isa<FunctionType>();
-  // The attribute must have the same type as 'type'.
-  if (!value.getType().isa<NoneType>() && value.getType() != type)
-    return false;
-  // Finally, check that the attribute kind is handled.
-  if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
-    auto complexTy = type.dyn_cast<ComplexType>();
-    if (!complexTy)
-      return false;
-    auto complexEltTy = complexTy.getElementType();
-    return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
-           arrAttr[1].getType() == complexEltTy;
-  }
-  return value.isa<UnitAttr>();
+  // Otherwise, this must be a UnitAttr.
+  return value.isa<UnitAttr>() && type.isa<NoneType>();
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index 038de9908cf2a..c68d87e8c0773 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -27,7 +27,7 @@ func @create_of_real_and_imag_different_operand(
 func @real_of_const() -> f32 {
   // CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
   // CHECK-NEXT: return %[[CST]] : f32
-  %complex = constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %complex = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
   %1 = complex.re %complex : complex<f32>
   return %1 : f32
 }
@@ -47,7 +47,7 @@ func @real_of_create_op() -> f32 {
 func @imag_of_const() -> f32 {
   // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
   // CHECK-NEXT: return %[[CST]] : f32
-  %complex = constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %complex = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
   %1 = complex.im %complex : complex<f32>
   return %1 : f32
 }

diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
new file mode 100644
index 0000000000000..ec046effacf8c
--- /dev/null
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -split-input-file %s -verify-diagnostics
+
+func @complex_constant_wrong_array_attribute_length() {
+  // expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}}
+  %0 = complex.constant [1.0 : f32] : complex<f32>
+  return
+}
+
+// -----
+
+func @complex_constant_wrong_element_types() {
+  // expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}}
+  %0 = complex.constant [1.0 : f32, -1.0 : f32] : complex<f64>
+  return
+}
+
+// -----
+
+func @complex_constant_two_different_element_types() {
+  // expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}}
+  %0 = complex.constant [1.0 : f32, -1.0 : f64] : complex<f64>
+  return
+}

diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index 3fc0e9299c0fd..75bb082efb2ab 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -5,6 +5,12 @@
 // CHECK-LABEL: func @ops(
 // CHECK-SAME:            %[[F:.*]]: f32) {
 func @ops(%f: f32) {
+  // CHECK: complex.constant [1.{{.*}}, -1.{{.*}}] : complex<f64>
+  %cst_f64 = complex.constant [0.1, -1.0] : complex<f64>
+
+  // CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex<f32>
+  %cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex<f32>
+
   // CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32>
   %complex = complex.create %f, %f : complex<f32>
 
@@ -51,4 +57,3 @@ func @ops(%f: f32) {
   %diff = complex.sub %complex, %complex : complex<f32>
   return
 }
-

diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index defd1aed97b1a..836158dd2160b 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -8,30 +8,6 @@ func @unsupported_attribute() {
 
 // -----
 
-func @complex_constant_wrong_array_attribute_length() {
-  // expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}}
-  %0 = constant [1.0 : f32] : complex<f32>
-  return
-}
-
-// -----
-
-func @complex_constant_wrong_element_types() {
-  // expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}}
-  %0 = constant [1.0 : f32, -1.0 : f32] : complex<f64>
-  return
-}
-
-// -----
-
-func @complex_constant_two_different_element_types() {
-  // expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}}
-  %0 = constant [1.0 : f32, -1.0 : f64] : complex<f64>
-  return
-}
-
-// -----
-
 func @return_i32_f32() -> (i32, f32) {
   %0 = arith.constant 1 : i32
   %1 = arith.constant 1. : f32

diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index 64322f066b354..1c3570eedb7ea 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -51,18 +51,6 @@ func @switch_i64(%flag : i64, %caseOperand : i32) {
     return
 }
 
-// CHECK-LABEL: func @constant_complex_f32(
-func @constant_complex_f32() -> complex<f32> {
-  %result = constant [0.1 : f32, -1.0 : f32] : complex<f32>
-  return %result : complex<f32>
-}
-
-// CHECK-LABEL: func @constant_complex_f64(
-func @constant_complex_f64() -> complex<f64> {
-  %result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
-  return %result : complex<f64>
-}
-
 // CHECK-LABEL: func @vector_splat_0d(
 func @vector_splat_0d(%a: f32) -> vector<f32> {
   // CHECK: splat %{{.*}} : vector<f32>


        


More information about the Mlir-commits mailing list