[Mlir-commits] [mlir] 1bd1eda - [mlir:ODS] Support using attributes in AllTypesMatch to automatically add InferTypeOpInterface

River Riddle llvmlistbot at llvm.org
Thu Apr 28 12:58:38 PDT 2022


Author: River Riddle
Date: 2022-04-28T12:57:59-07:00
New Revision: 1bd1edaf4006ff66a88ac59e0931f22105003a26

URL: https://github.com/llvm/llvm-project/commit/1bd1edaf4006ff66a88ac59e0931f22105003a26
DIFF: https://github.com/llvm/llvm-project/commit/1bd1edaf4006ff66a88ac59e0931f22105003a26.diff

LOG: [mlir:ODS] Support using attributes in AllTypesMatch to automatically add InferTypeOpInterface

This allows for using attribute types in result type inference for use with
InferTypeOpInterface. This was a TODO before, but it isn't much
additional work to properly support this. After this commit,
arith::ConstantOp can now have its InferTypeOpInterface implementation automatically
generated.

Differential Revision: https://reviews.llvm.org/D124580

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/include/mlir/TableGen/CodeGenHelpers.h
    mlir/include/mlir/TableGen/Constraint.h
    mlir/lib/TableGen/Constraint.cpp
    mlir/lib/TableGen/Operator.cpp
    mlir/python/mlir/dialects/_arith_ops_ext.py
    mlir/test/Dialect/Arithmetic/invalid.mlir
    mlir/test/IR/diagnostic-handler.mlir
    mlir/test/mlir-tblgen/op-result.td
    mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 6305c6947c75..a46c5119296b 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -124,9 +124,7 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
 def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
     [ConstantLike, NoSideEffect,
      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-     TypesMatchWith<
-    "result and attribute have the same type",
-    "value", "result", "$_self">]> {
+     AllTypesMatch<["value", "result"]>]> {
   let summary = "integer or floating point constant";
   let description = [{
     The `constant` operation produces an SSA value equal to some integer or
@@ -154,8 +152,6 @@ def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
   let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
 
   let builders = [
-    OpBuilder<(ins "Attribute":$value),
-    [{ build($_builder, $_state, value.getType(), value); }]>,
     OpBuilder<(ins "Attribute":$value, "Type":$type),
     [{ build($_builder, $_state, type, value); }]>,
   ];

diff  --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index d4d3294ea2a6..69b3a897fc38 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -187,19 +187,9 @@ class StaticVerifierFunctionEmitter {
   /// ensure that the static functions have a unique name.
   std::string uniqueOutputLabel;
 
-  /// Unique constraints by their predicate and summary. Constraints that share
-  /// the same predicate may have different descriptions; ensure that the
-  /// correct error message is reported when verification fails.
-  struct ConstraintUniquer {
-    static Constraint getEmptyKey();
-    static Constraint getTombstoneKey();
-    static unsigned getHashValue(Constraint constraint);
-    static bool isEqual(Constraint lhs, Constraint rhs);
-  };
   /// Use a MapVector to ensure that functions are generated deterministically.
-  using ConstraintMap =
-      llvm::MapVector<Constraint, std::string,
-                      llvm::DenseMap<Constraint, unsigned, ConstraintUniquer>>;
+  using ConstraintMap = llvm::MapVector<Constraint, std::string,
+                                        llvm::DenseMap<Constraint, unsigned>>;
 
   /// A generic function to emit constraints
   void emitConstraints(const ConstraintMap &constraints, StringRef selfName,

diff  --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h
index 0f6c2b58faaf..0c74f89189e9 100644
--- a/mlir/include/mlir/TableGen/Constraint.h
+++ b/mlir/include/mlir/TableGen/Constraint.h
@@ -94,4 +94,20 @@ struct AppliedConstraint {
 } // namespace tblgen
 } // namespace mlir
 
+namespace llvm {
+/// Unique constraints by their predicate and summary. Constraints that share
+/// the same predicate may have different descriptions; ensure that the
+/// correct error message is reported when verification fails.
+template <>
+struct DenseMapInfo<mlir::tblgen::Constraint> {
+  using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
+
+  static mlir::tblgen::Constraint getEmptyKey();
+  static mlir::tblgen::Constraint getTombstoneKey();
+  static unsigned getHashValue(mlir::tblgen::Constraint constraint);
+  static bool isEqual(mlir::tblgen::Constraint lhs,
+                      mlir::tblgen::Constraint rhs);
+};
+} // namespace llvm
+
 #endif // MLIR_TABLEGEN_CONSTRAINT_H_

diff  --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index 0c5e034a8ee6..8e62120ad837 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -108,3 +108,34 @@ AppliedConstraint::AppliedConstraint(Constraint &&constraint,
                                      std::vector<std::string> &&entities)
     : constraint(constraint), self(std::string(self)),
       entities(std::move(entities)) {}
+
+Constraint DenseMapInfo<Constraint>::getEmptyKey() {
+  return Constraint(RecordDenseMapInfo::getEmptyKey(),
+                    Constraint::CK_Uncategorized);
+}
+
+Constraint DenseMapInfo<Constraint>::getTombstoneKey() {
+  return Constraint(RecordDenseMapInfo::getTombstoneKey(),
+                    Constraint::CK_Uncategorized);
+}
+
+unsigned DenseMapInfo<Constraint>::getHashValue(Constraint constraint) {
+  if (constraint == getEmptyKey())
+    return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
+  if (constraint == getTombstoneKey()) {
+    return RecordDenseMapInfo::getHashValue(
+        RecordDenseMapInfo::getTombstoneKey());
+  }
+  return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
+}
+
+bool DenseMapInfo<Constraint>::isEqual(Constraint lhs, Constraint rhs) {
+  if (lhs == rhs)
+    return true;
+  if (lhs == getEmptyKey() || lhs == getTombstoneKey())
+    return false;
+  if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+    return false;
+  return lhs.getPredicate() == rhs.getPredicate() &&
+         lhs.getSummary() == rhs.getSummary();
+}

diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 2a0d49fcfccf..35afb8d7f694 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -357,10 +357,6 @@ void Operator::populateTypeInferenceInfo(
         continue;
       }
 
-      if (getArg(*mi).is<NamedAttribute *>()) {
-        // TODO: Handle attributes.
-        continue;
-      }
       resultTypeMapping[i].emplace_back(*mi);
       found = true;
     }

diff  --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py
index e35f5f2a4794..c755df255c1e 100644
--- a/mlir/python/mlir/dialects/_arith_ops_ext.py
+++ b/mlir/python/mlir/dialects/_arith_ops_ext.py
@@ -41,11 +41,11 @@ def __init__(self,
                loc=None,
                ip=None):
     if isinstance(value, int):
-      super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip)
+      super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
     elif isinstance(value, float):
-      super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip)
+      super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
     else:
-      super().__init__(result, value, loc=loc, ip=ip)
+      super().__init__(value, loc=loc, ip=ip)
 
   @classmethod
   def create_index(cls, value: int, *, loc=None, ip=None):

diff  --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir
index 71014b10726f..47f5f1f511c1 100644
--- a/mlir/test/Dialect/Arithmetic/invalid.mlir
+++ b/mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -25,7 +25,7 @@ func.func @non_signless_constant() {
 // -----
 
 func.func @complex_constant_wrong_attribute_type() {
-  // expected-error @+1 {{'arith.constant' op failed to verify that result and attribute have the same type}}
+  // expected-error @+1 {{'arith.constant' op failed to verify that all of {value, result} have same type}}
   %0 = "arith.constant" () {value = 1.0 : f32} : () -> complex<f32>
   return
 }
@@ -50,7 +50,7 @@ func.func @bitcast_different_bit_widths(%arg : f16) -> f32 {
 
 func.func @constant() {
 ^bb:
-  %x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
+  %x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
   return
 }
 
@@ -58,7 +58,7 @@ func.func @constant() {
 
 func.func @constant_out_of_range() {
 ^bb:
-  %x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
+  %x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
   return
 }
 
@@ -66,7 +66,7 @@ func.func @constant_out_of_range() {
 
 func.func @constant_wrong_type() {
 ^bb:
-  %x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
+  %x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
   return
 }
 

diff  --git a/mlir/test/IR/diagnostic-handler.mlir b/mlir/test/IR/diagnostic-handler.mlir
index f94a632bbb8f..592656cefeb5 100644
--- a/mlir/test/IR/diagnostic-handler.mlir
+++ b/mlir/test/IR/diagnostic-handler.mlir
@@ -5,7 +5,7 @@
 
 // Emit the first available call stack in the fused location.
 func.func @constant_out_of_range() {
-  // CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that result and attribute have the same type
+  // CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that all of {value, result} have same type
   // CHECK-NEXT: mysource2:1:0: note: called from
   // CHECK-NEXT: mysource3:2:0: note: called from
   %x = "arith.constant"() {value = 100} : () -> i1 loc(fused["bar", callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0))])

diff  --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index a190a47a9eea..d4d8746ac441 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -123,7 +123,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint",
 
 // CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
 // CHECK-NOT: }
-// CHECK: inferredReturnTypes[0] = operands[0].getType();
+// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
+// CHECK: inferredReturnTypes[0] = odsInferredType0;
 
 def OpL2 : NS_Op<"op_with_all_types_constraint",
     [AllTypesMatch<["c", "b"]>, AllTypesMatch<["a", "d"]>]> {
@@ -133,5 +134,18 @@ def OpL2 : NS_Op<"op_with_all_types_constraint",
 
 // CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
 // CHECK-NOT: }
-// CHECK: inferredReturnTypes[0] = operands[2].getType();
-// CHECK: inferredReturnTypes[1] = operands[0].getType();
+// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
+// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
+// CHECK: inferredReturnTypes[0] = odsInferredType0;
+// CHECK: inferredReturnTypes[1] = odsInferredType1;
+
+def OpL3 : NS_Op<"op_with_all_types_constraint",
+    [AllTypesMatch<["a", "b"]>]> {
+  let arguments = (ins I32Attr:$a);
+  let results = (outs AnyType:$b);
+}
+
+// CHECK-LABEL: LogicalResult OpL3::inferReturnTypes
+// CHECK-NOT: }
+// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").getType();
+// CHECK: inferredReturnTypes[0] = odsInferredType0;

diff  --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
index 7fb6d953b47f..80fb27b76cad 100644
--- a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
@@ -234,41 +234,6 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
 //===----------------------------------------------------------------------===//
 // Constraint Uniquing
 
-using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
-
-Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getEmptyKey() {
-  return Constraint(RecordDenseMapInfo::getEmptyKey(),
-                    Constraint::CK_Uncategorized);
-}
-
-Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getTombstoneKey() {
-  return Constraint(RecordDenseMapInfo::getTombstoneKey(),
-                    Constraint::CK_Uncategorized);
-}
-
-unsigned StaticVerifierFunctionEmitter::ConstraintUniquer::getHashValue(
-    Constraint constraint) {
-  if (constraint == getEmptyKey())
-    return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
-  if (constraint == getTombstoneKey()) {
-    return RecordDenseMapInfo::getHashValue(
-        RecordDenseMapInfo::getTombstoneKey());
-  }
-  return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
-}
-
-bool StaticVerifierFunctionEmitter::ConstraintUniquer::isEqual(Constraint lhs,
-                                                               Constraint rhs) {
-  if (lhs == rhs)
-    return true;
-  if (lhs == getEmptyKey() || lhs == getTombstoneKey())
-    return false;
-  if (rhs == getEmptyKey() || rhs == getTombstoneKey())
-    return false;
-  return lhs.getPredicate() == rhs.getPredicate() &&
-         lhs.getSummary() == rhs.getSummary();
-}
-
 /// An attribute constraint that references anything other than itself and the
 /// current op cannot be generically extracted into a function. Most
 /// prohibitive are operands and results, which require calls to

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 6f5f82aa7dd5..c80fe5a22641 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2336,23 +2336,60 @@ void OpEmitter::genTypeInterfaceMethods() {
   fctx.withBuilder("odsBuilder");
   body << "  ::mlir::Builder odsBuilder(context);\n";
 
-  auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & {
-    if (!type.isArg())
-      return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
-    auto argIndex = type.getArg();
-    assert(!op.getArg(argIndex).is<NamedAttribute *>());
+  // Preprocess the result types and build all of the types used during
+  // inferrence. This limits the amount of duplicated work when a type is used
+  // to infer multiple others.
+  llvm::DenseMap<Constraint, int> constraintsTypes;
+  llvm::DenseMap<int, int> argumentsTypes;
+  int inferredTypeIdx = 0;
+  for (int i = 0, e = op.getNumResults(); i != e; ++i) {
+    auto type = op.getSameTypeAsResult(i).front();
+
+    // If the type isn't an argument, it refers to a buildable type.
+    if (!type.isArg()) {
+      auto it = constraintsTypes.try_emplace(type.getType(), inferredTypeIdx);
+      if (!it.second)
+        continue;
+
+      // If we haven't seen this constraint, generate a variable for it.
+      body << "  ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
+           << tgfmt(*type.getType().getBuilderCall(), &fctx) << ";\n";
+      continue;
+    }
+
+    // Otherwise, this is an argument.
+    int argIndex = type.getArg();
+    auto it = argumentsTypes.try_emplace(argIndex, inferredTypeIdx);
+    if (!it.second)
+      continue;
+    body << "  ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = ";
+
+    // If this is an operand, just index into operand list to access the type.
     auto arg = op.getArgToOperandOrAttribute(argIndex);
-    if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
-      return body << "operands[" << arg.operandOrAttributeIndex()
-                  << "].getType()";
-    return body << "attributes[" << arg.operandOrAttributeIndex()
-                << "].getType()";
-  };
+    if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
+      body << "operands[" << arg.operandOrAttributeIndex() << "].getType()";
+
+      // If this is an attribute, index into the attribute dictionary.
+    } else {
+      auto *attr =
+          op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
+      body << "attributes.get(\"" << attr->name << "\").getType()";
+    }
+    body << ";\n";
+  }
 
+  // Perform a second pass that handles assigning the inferred types to the
+  // results.
   for (int i = 0, e = op.getNumResults(); i != e; ++i) {
-    body << "  inferredReturnTypes[" << i << "] = ";
     auto types = op.getSameTypeAsResult(i);
-    emitType(types[0]) << ";\n";
+
+    // Append the inferred type.
+    auto type = types.front();
+    body << "  inferredReturnTypes[" << i << "] = odsInferredType"
+         << (type.isArg() ? argumentsTypes[type.getArg()]
+                          : constraintsTypes[type.getType()])
+         << ";\n";
+
     if (types.size() == 1)
       continue;
     // TODO: We could verify equality here, but skipping that for verification.


        


More information about the Mlir-commits mailing list