[Mlir-commits] [mlir] c108966 - [mlir] Generate CmpFPredicate as an EnumAttr in tablegen

River Riddle llvmlistbot at llvm.org
Tue Mar 3 13:20:12 PST 2020


Author: River Riddle
Date: 2020-03-03T13:19:25-08:00
New Revision: c10896682d0bb457b9d77fdd753ed9e7e6806db1

URL: https://github.com/llvm/llvm-project/commit/c10896682d0bb457b9d77fdd753ed9e7e6806db1
DIFF: https://github.com/llvm/llvm-project/commit/c10896682d0bb457b9d77fdd753ed9e7e6806db1.diff

LOG: [mlir] Generate CmpFPredicate as an EnumAttr in tablegen

Summary: This allows for attaching the attribute to CmpF as a proper argument, and thus enables the removal of a bunch of C++ code.

Differential Revision: https://reviews.llvm.org/D75539

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index c9b9bb0e32f1..1e19c0270416 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -40,37 +40,6 @@ class StandardOpsDialect : public Dialect {
                                  Location loc) override;
 };
 
-/// The predicate indicates the type of the comparison to perform:
-/// (un)orderedness, (in)equality and less/greater than (or equal to) as
-/// well as predicates that are always true or false.
-enum class CmpFPredicate {
-  FirstValidValue,
-  // Always false
-  AlwaysFalse = FirstValidValue,
-  // Ordered comparisons
-  OEQ,
-  OGT,
-  OGE,
-  OLT,
-  OLE,
-  ONE,
-  // Both ordered
-  ORD,
-  // Unordered comparisons
-  UEQ,
-  UGT,
-  UGE,
-  ULT,
-  ULE,
-  UNE,
-  // Any unordered
-  UNO,
-  // Always true
-  AlwaysTrue,
-  // Number of predicates.
-  NumPredicates
-};
-
 #define GET_OP_CLASSES
 #include "mlir/Dialect/StandardOps/IR/Ops.h.inc"
 

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 85870010f0e2..851a6434a9a1 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -433,6 +433,34 @@ def CeilFOp : FloatUnaryOp<"ceilf"> {
   }];
 }
 
+// The predicate indicates the type of the comparison to perform:
+// (un)orderedness, (in)equality and less/greater than (or equal to) as
+// well as predicates that are always true or false.
+def CMPF_P_FALSE   : I64EnumAttrCase<"AlwaysFalse", 0, "false">;
+def CMPF_P_OEQ     : I64EnumAttrCase<"OEQ", 1, "oeq">;
+def CMPF_P_OGT     : I64EnumAttrCase<"OGT", 2, "ogt">;
+def CMPF_P_OGE     : I64EnumAttrCase<"OGE", 3, "oge">;
+def CMPF_P_OLT     : I64EnumAttrCase<"OLT", 4, "olt">;
+def CMPF_P_OLE     : I64EnumAttrCase<"OLE", 5, "ole">;
+def CMPF_P_ONE     : I64EnumAttrCase<"ONE", 6, "one">;
+def CMPF_P_ORD     : I64EnumAttrCase<"ORD", 7, "ord">;
+def CMPF_P_UEQ     : I64EnumAttrCase<"UEQ", 8, "ueq">;
+def CMPF_P_UGT     : I64EnumAttrCase<"UGT", 9, "ugt">;
+def CMPF_P_UGE     : I64EnumAttrCase<"UGE", 10, "uge">;
+def CMPF_P_ULT     : I64EnumAttrCase<"ULT", 11, "ult">;
+def CMPF_P_ULE     : I64EnumAttrCase<"ULE", 12, "ule">;
+def CMPF_P_UNE     : I64EnumAttrCase<"UNE", 13, "une">;
+def CMPF_P_UNO     : I64EnumAttrCase<"UNO", 14, "uno">;
+def CMPF_P_TRUE    : I64EnumAttrCase<"AlwaysTrue", 15, "true">;
+
+def CmpFPredicateAttr : I64EnumAttr<
+    "CmpFPredicate", "",
+    [CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE,
+     CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT,
+     CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> {
+  let cppNamespace = "::mlir";
+}
+
 def CmpFOp : Std_Op<"cmpf",
     [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
      TypesMatchWith<
@@ -461,7 +489,11 @@ def CmpFOp : Std_Op<"cmpf",
       %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1
   }];
 
-  let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
+  let arguments = (ins
+    CmpFPredicateAttr:$predicate,
+    FloatLike:$lhs,
+    FloatLike:$rhs
+  );
   let results = (outs BoolLike:$result);
 
   let builders = [OpBuilder<
@@ -480,7 +512,11 @@ def CmpFOp : Std_Op<"cmpf",
     }
   }];
 
+  let verifier = [{ return success(); }];
+
   let hasFolder = 1;
+
+  let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
 }
 
 def CMPI_P_EQ  : I64EnumAttrCase<"eq", 0>;

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 4d88aa8e99da..4f16c76fb7d9 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -580,55 +580,6 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
 // CmpFOp
 //===----------------------------------------------------------------------===//
 
-// Returns an array of mnemonics for CmpFPredicates indexed by values thereof.
-static inline const char *const *getCmpFPredicateNames() {
-  static const char *predicateNames[] = {
-      /*AlwaysFalse*/ "false",
-      /*OEQ*/ "oeq",
-      /*OGT*/ "ogt",
-      /*OGE*/ "oge",
-      /*OLT*/ "olt",
-      /*OLE*/ "ole",
-      /*ONE*/ "one",
-      /*ORD*/ "ord",
-      /*UEQ*/ "ueq",
-      /*UGT*/ "ugt",
-      /*UGE*/ "uge",
-      /*ULT*/ "ult",
-      /*ULE*/ "ule",
-      /*UNE*/ "une",
-      /*UNO*/ "uno",
-      /*AlwaysTrue*/ "true",
-  };
-  static_assert(std::extent<decltype(predicateNames)>::value ==
-                    (size_t)CmpFPredicate::NumPredicates,
-                "wrong number of predicate names");
-  return predicateNames;
-}
-
-// Returns a value of the predicate corresponding to the given mnemonic.
-// Returns NumPredicates (one-past-end) if there is no such mnemonic.
-CmpFPredicate CmpFOp::getPredicateByName(StringRef name) {
-  return llvm::StringSwitch<CmpFPredicate>(name)
-      .Case("false", CmpFPredicate::AlwaysFalse)
-      .Case("oeq", CmpFPredicate::OEQ)
-      .Case("ogt", CmpFPredicate::OGT)
-      .Case("oge", CmpFPredicate::OGE)
-      .Case("olt", CmpFPredicate::OLT)
-      .Case("ole", CmpFPredicate::OLE)
-      .Case("one", CmpFPredicate::ONE)
-      .Case("ord", CmpFPredicate::ORD)
-      .Case("ueq", CmpFPredicate::UEQ)
-      .Case("ugt", CmpFPredicate::UGT)
-      .Case("uge", CmpFPredicate::UGE)
-      .Case("ult", CmpFPredicate::ULT)
-      .Case("ule", CmpFPredicate::ULE)
-      .Case("une", CmpFPredicate::UNE)
-      .Case("uno", CmpFPredicate::UNO)
-      .Case("true", CmpFPredicate::AlwaysTrue)
-      .Default(CmpFPredicate::NumPredicates);
-}
-
 static void buildCmpFOp(Builder *build, OperationState &result,
                         CmpFPredicate predicate, Value lhs, Value rhs) {
   result.addOperands({lhs, rhs});
@@ -638,73 +589,8 @@ static void buildCmpFOp(Builder *build, OperationState &result,
       build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
 }
 
-static ParseResult parseCmpFOp(OpAsmParser &parser, OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 2> ops;
-  SmallVector<NamedAttribute, 4> attrs;
-  Attribute predicateNameAttr;
-  Type type;
-  if (parser.parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(),
-                            attrs) ||
-      parser.parseComma() || parser.parseOperandList(ops, 2) ||
-      parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) ||
-      parser.resolveOperands(ops, type, result.operands))
-    return failure();
-
-  if (!predicateNameAttr.isa<StringAttr>())
-    return parser.emitError(parser.getNameLoc(),
-                            "expected string comparison predicate attribute");
-
-  // Rewrite string attribute to an enum value.
-  StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
-  auto predicate = CmpFOp::getPredicateByName(predicateName);
-  if (predicate == CmpFPredicate::NumPredicates)
-    return parser.emitError(parser.getNameLoc(),
-                            "unknown comparison predicate \"" + predicateName +
-                                "\"");
-
-  auto builder = parser.getBuilder();
-  Type i1Type = getCheckedI1SameShape(type);
-  if (!i1Type)
-    return parser.emitError(parser.getNameLoc(),
-                            "expected type with valid i1 shape");
-
-  attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
-  result.attributes = attrs;
-
-  result.addTypes({i1Type});
-  return success();
-}
-
-static void print(OpAsmPrinter &p, CmpFOp op) {
-  p << "cmpf ";
-
-  auto predicateValue =
-      op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt();
-  assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
-         predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
-         "unknown predicate index");
-  p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs()
-    << ", " << op.rhs();
-  p.printOptionalAttrDict(op.getAttrs(),
-                          /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
-  p << " : " << op.lhs().getType();
-}
-
-static LogicalResult verify(CmpFOp op) {
-  auto predicateAttr =
-      op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName());
-  if (!predicateAttr)
-    return op.emitOpError("requires an integer attribute named 'predicate'");
-  auto predicate = predicateAttr.getInt();
-  if (predicate < (int64_t)CmpFPredicate::FirstValidValue ||
-      predicate >= (int64_t)CmpFPredicate::NumPredicates)
-    return op.emitOpError("'predicate' attribute value out of range");
-
-  return success();
-}
-
-// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
-// comparison predicates.
+/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
+/// comparison predicates.
 static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                               const APFloat &rhs) {
   auto cmpResult = lhs.compare(rhs);

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 5b43103e9018..07b6b9f4b121 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -346,28 +346,28 @@ func @invalid_cmp_attr(%idx : i32) {
 // -----
 
 func @cmpf_generic_invalid_predicate_value(%a : f32) {
-  // expected-error@+1 {{'predicate' attribute value out of range}}
+  // expected-error@+1 {{attribute 'predicate' failed to satisfy constraint: allowed 64-bit integer cases}}
   %r = "std.cmpf"(%a, %a) {predicate = 42} : (f32, f32) -> i1
 }
 
 // -----
 
 func @cmpf_canonical_invalid_predicate_value(%a : f32) {
-  // expected-error@+1 {{unknown comparison predicate "foo"}}
+  // expected-error@+1 {{invalid predicate attribute specification: "foo"}}
   %r = cmpf "foo", %a, %a : f32
 }
 
 // -----
 
 func @cmpf_canonical_invalid_predicate_value_signed(%a : f32) {
-  // expected-error@+1 {{unknown comparison predicate "sge"}}
+  // expected-error@+1 {{invalid predicate attribute specification: "sge"}}
   %r = cmpf "sge", %a, %a : f32
 }
 
 // -----
 
 func @cmpf_canonical_invalid_predicate_value_no_order(%a : f32) {
-  // expected-error@+1 {{unknown comparison predicate "eq"}}
+  // expected-error@+1 {{invalid predicate attribute specification: "eq"}}
   %r = cmpf "eq", %a, %a : f32
 }
 
@@ -380,14 +380,14 @@ func @cmpf_canonical_no_predicate_attr(%a : f32, %b : f32) {
 // -----
 
 func @cmpf_generic_no_predicate_attr(%a : f32, %b : f32) {
-  // expected-error@+1 {{requires an integer attribute named 'predicate'}}
+  // expected-error@+1 {{requires attribute 'predicate'}}
   %r = "std.cmpf"(%a, %b) {foo = 1} : (f32, f32) -> i1
 }
 
 // -----
 
 func @cmpf_wrong_type(%a : i32, %b : i32) {
-  %r = cmpf "oeq", %a, %b : i32 // expected-error {{operand #0 must be floating-point-like}}
+  %r = cmpf "oeq", %a, %b : i32 // expected-error {{must be floating-point-like}}
 }
 
 // -----


        


More information about the Mlir-commits mailing list