[Mlir-commits] [mlir] 1b60f0d - [mlir][ods] Generate inferReturnTypes for ops with TypesMatchWith
Jeff Niu
llvmlistbot at llvm.org
Thu Jan 12 13:26:17 PST 2023
Author: Jeff Niu
Date: 2023-01-12T13:26:12-08:00
New Revision: 1b60f0d73c34fec4648bb05f98db75008a50f4d8
URL: https://github.com/llvm/llvm-project/commit/1b60f0d73c34fec4648bb05f98db75008a50f4d8
DIFF: https://github.com/llvm/llvm-project/commit/1b60f0d73c34fec4648bb05f98db75008a50f4d8.diff
LOG: [mlir][ods] Generate inferReturnTypes for ops with TypesMatchWith
Ops that use TypesMatchWith to constrain result types for verification
and to infer result types during parser generation should also be able
to have the `inferReturnTypes` method auto-generated. This patch
upgrades the logic for generating `inferReturnTypes` to handle the
TypesMatchWith trait by building a type inference graph where each edge
corresponds to "type of A can be inferred from type of B", supporting
transformers other than `"$_self"`.
Reviewed By: lattner, rriddle
Differential Revision: https://reviews.llvm.org/D141231
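To make the effect concrete: the new OpL4 test case in mlir/test/mlir-tblgen/op-result.td (part of the diff below) chains three results through TypesMatchWith using the placeholder transformers fromInput, infer0 and infer1. For such an op, ODS now emits an inferReturnTypes definition roughly along these lines (a sketch reconstructed from the test's CHECK lines; the exact signature and boilerplate may differ, and the transformer calls are the test's placeholders rather than real functions):

// Sketch only: mirrors the CHECK lines of the OpL4 test below; signature
// approximated, fromInput/infer0/infer1 are test placeholders.
::mlir::LogicalResult OpL4::inferReturnTypes(
    ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
    ::mlir::RegionRange regions,
    ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
  ::mlir::Builder odsBuilder(context);
  inferredReturnTypes.resize(3);
  // $a is inferred from the operand, $b from $a, and $c from $b, following
  // the edges of the type inference graph in topological order.
  ::mlir::Type odsInferredType0 = fromInput(operands[0].getType());
  ::mlir::Type odsInferredType1 = infer0(odsInferredType0);
  ::mlir::Type odsInferredType2 = infer1(odsInferredType1);
  inferredReturnTypes[0] = odsInferredType0;
  inferredReturnTypes[1] = odsInferredType1;
  inferredReturnTypes[2] = odsInferredType2;
  return ::mlir::success();
}

Each result type is produced in topological order over the inference graph, so a result inferred from another result is emitted only after its source type has been built.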
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/TableGen/Operator.h
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
mlir/lib/TableGen/Operator.cpp
mlir/test/mlir-tblgen/op-result.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index b531b5ed9f288..a4e542658ec77 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1305,13 +1305,6 @@ def Arith_CmpIOp
SignlessIntegerLikeOfAnyRank:$lhs,
SignlessIntegerLikeOfAnyRank:$rhs);
- let builders = [
- OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{
- build($_builder, $_state, ::getI1SameShape(lhs.getType()),
- predicate, lhs, rhs);
- }]>
- ];
-
let extraClassDeclaration = [{
static arith::CmpIPredicate getPredicateByName(StringRef name);
}];
@@ -1356,13 +1349,6 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf"> {
FloatLike:$lhs,
FloatLike:$rhs);
- let builders = [
- OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{
- build($_builder, $_state, ::getI1SameShape(lhs.getType()),
- predicate, lhs, rhs);
- }]>
- ];
-
let extraClassDeclaration = [{
static arith::CmpFPredicate getPredicateByName(StringRef name);
}];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 74ed475dc74d6..279f5973bd807 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -294,12 +294,6 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
"the reference to load from", [MemRead]>:$memref);
let results = (outs AnyTensor:$result);
- let builders = [
- OpBuilder<(ins "Value":$memref), [{
- $_state.addOperands(memref);
- $_state.addTypes(memref::getTensorTypeFromMemRefType(memref.getType()));
- }]>];
-
let extraClassDeclaration = [{
/// The result of a to_tensor is always a tensor.
TensorType getType() {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 63c4fea362048..5f96a4682107e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -119,9 +119,6 @@ def LLVM_ICmpOp : LLVM_ArithmeticCmpOp<"icmp", [Pure]> {
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs,
AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs);
- let builders = [
- OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
- ];
let hasCustomAssemblyFormat = 1;
string llvmInstName = "ICmp";
string llvmBuilder = [{
@@ -145,9 +142,6 @@ def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags);
- let builders = [
- OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
- ];
let hasCustomAssemblyFormat = 1;
string llvmInstName = "FCmp";
string llvmBuilder = [{
@@ -583,11 +577,6 @@ def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [Pure,
let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position);
let results = (outs LLVM_Type:$res);
- let builders = [
- OpBuilder<(ins "Value":$vector, "Value":$position,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
- ];
-
let assemblyFormat = [{
$vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
}];
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c88450e42f9ab..121deb7839942 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1158,14 +1158,6 @@ def LoadOp : MemRef_Op<"load",
Variadic<Index>:$indices);
let results = (outs AnyType:$result);
- let builders = [
- OpBuilder<(ins "Value":$memref, CArg<"ValueRange", "{}">:$indices), [{
- auto memrefType = memref.getType().cast<MemRefType>();
- $_state.addOperands(memref);
- $_state.addOperands(indices);
- $_state.types.push_back(memrefType.getElementType());
- }]>];
-
let extraClassDeclaration = [{
Value getMemRef() { return getOperand(0); }
void setMemRef(Value value) { setOperand(0, value); }
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 02959ae9940e3..d6e90724edc27 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -30,11 +30,6 @@ class SPIRV_LogicalBinaryOp<string mnemonic, Type operandsType,
"getUnaryOpResultType($_self)"
>])> {
let assemblyFormat = "$operand1 `,` $operand2 `:` type($operand1) attr-dict";
-
- let builders = [
- OpBuilder<(ins "Value":$lhs, "Value":$rhs),
- [{::buildLogicalBinaryOp($_builder, $_state, lhs, rhs);}]>
- ];
}
class SPIRV_LogicalUnaryOp<string mnemonic, Type operandType,
@@ -49,11 +44,6 @@ class SPIRV_LogicalUnaryOp<string mnemonic, Type operandType,
"getUnaryOpResultType($_self)"
>])> {
let assemblyFormat = "$operand `:` type($operand) attr-dict";
-
- let builders = [
- OpBuilder<(ins "Value":$value),
- [{::buildLogicalUnaryOp($_builder, $_state, value);}]>
- ];
}
// -----
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3535146e3c05c..ccd53b7f0bf51 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -237,12 +237,6 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
let results = (outs AnyType:$result);
let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";
- let builders = [
- OpBuilder<(ins "Value":$tensor, CArg<"ValueRange", "{}">:$indices), [{
- auto resType = tensor.getType().cast<ShapedType>().getElementType();
- build($_builder, $_state, resType, tensor, indices);
- }]>];
-
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
@@ -292,7 +286,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
between different flavors of ops on that operate on tensors.
#### Verification vs Inference in the rank-reduced case
-
+
Note that there may be multiple ways to infer a resulting rank-reduced type.
e.g. 1x6x1 could potentially rank-reduce to either 1x6 or 6x1 2-D shapes.
@@ -724,13 +718,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
$scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
}];
- let builders = [
- OpBuilder<(ins "Value":$scalar, "Value":$dest,
- CArg<"ValueRange", "{}">:$indices), [{
- auto resType = dest.getType();
- build($_builder, $_state, resType, scalar, dest, indices);
- }]>];
-
let extraClassDeclaration = [{
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
return {1, 2}; // `dest` operand
@@ -795,7 +782,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
behavior of tensor.extract_slice.
#### Verification in the rank-reduced case
-
+
The same verification discussion and mechanisms apply as for ExtractSliceOp.
Unlike ExtractSliceOp however, there is no need for a specific inference.
@@ -1399,7 +1386,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
rank-reducing behavior of tensor.insert_slice and tensor.extract_slice.
#### Verification in the rank-reduced case
-
+
The same verification discussion and mechanisms apply as for ExtractSliceOp.
Unlike ExtractSliceOp however, there is no need for a specific inference.
}];
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index df5b7c597f6ab..04af8d3d80af2 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -576,8 +576,6 @@ def Vector_ExtractElementOp :
let builders = [
// 0-D builder.
OpBuilder<(ins "Value":$source)>,
- // 1-D + position builder.
- OpBuilder<(ins "Value":$source, "Value":$position)>,
];
let extraClassDeclaration = [{
VectorType getVectorType() {
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index f4a475d60700b..88dff220d6378 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -37,6 +37,39 @@ class StringInit;
namespace mlir {
namespace tblgen {
+/// This class represents an inferred result type. The result type can be
+/// inferred from an argument or result type. If it is inferred from another
+/// result type, that type must be buildable or inferred from yet another type.
+class InferredResultType {
+public:
+ InferredResultType(int index, std::string transformer)
+ : index(index), transformer(std::move(transformer)) {}
+
+ /// Returns true if result type is inferred from an argument type.
+ bool isArg() const { return isArgIndex(index); }
+ /// Return the mapped argument or result index.
+ int getIndex() const { return index; }
+ /// If the type is inferred from a result, return the result index.
+ int getResultIndex() const { return unmapResultIndex(index); }
+
+ // Mapping from result index to combined argument and result index.
+ // Arguments are indexed to match getArg index, while the result indexes are
+ // mapped to avoid overlap.
+ static int mapResultIndex(int i) { return -1 - i; }
+ static int unmapResultIndex(int i) { return -i - 1; }
+ static bool isResultIndex(int i) { return i < 0; }
+ static bool isArgIndex(int i) { return i >= 0; }
+
+ StringRef getTransformer() const { return transformer; }
+
+private:
+ /// The index of the source argument or result.
+ int index;
+
+ /// The transformer to apply to the type to obtain the inferred type.
+ std::string transformer;
+};
+
/// Wrapper class that contains a MLIR op's information (e.g., operands,
/// attributes) defined in TableGen and provides helper methods for
/// accessing them.
@@ -259,32 +292,9 @@ class Operator {
/// Return whether all the result types are known.
bool allResultTypesKnown() const { return allResultsHaveKnownTypes; };
- /// Pair representing either a index to an argument or a type constraint. Only
- /// one of these entries should have the non-default value.
- struct ArgOrType {
- explicit ArgOrType(int index) : index(index), constraint(std::nullopt) {}
- explicit ArgOrType(TypeConstraint constraint)
- : index(std::nullopt), constraint(constraint) {}
- bool isArg() const {
- assert(constraint.has_value() ^ index.has_value());
- return index.has_value();
- }
- bool isType() const {
- assert(constraint.has_value() ^ index.has_value());
- return constraint.has_value();
- }
-
- int getArg() const { return *index; }
- TypeConstraint getType() const { return *constraint; }
-
- private:
- std::optional<int> index;
- std::optional<TypeConstraint> constraint;
- };
-
- /// Return all arguments or type constraints with same type as result[index].
+ /// Return all arguments or type constraints with same type as result[index].
/// Requires: all result types are known.
- ArrayRef<ArgOrType> getSameTypeAsResult(int index) const;
+ const InferredResultType &getInferredResultType(int index) const;
/// Pair consisting kind of argument and index into operands or attributes.
struct OperandOrAttribute {
@@ -359,7 +369,7 @@ class Operator {
SmallVector<NamedRegion, 1> regions;
/// The argument with the same type as the result.
- SmallVector<SmallVector<ArgOrType, 2>, 4> resultTypeMapping;
+ SmallVector<InferredResultType> resultTypeMapping;
/// Map from argument to attribute or operand number.
SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 82ec921497540..2664996e03b7a 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -298,7 +298,7 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
b.setInsertionPoint(xferOp);
b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
- result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
+ result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
}
return result;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 4a6ed4c5c7fe5..09b118d717949 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -101,16 +101,6 @@ static Type getI1SameShape(Type type) {
// Printing, parsing and builder for LLVM::CmpOp.
//===----------------------------------------------------------------------===//
-void ICmpOp::build(OpBuilder &builder, OperationState &result,
- ICmpPredicate predicate, Value lhs, Value rhs) {
- build(builder, result, getI1SameShape(lhs.getType()), predicate, lhs, rhs);
-}
-
-void FCmpOp::build(OpBuilder &builder, OperationState &result,
- FCmpPredicate predicate, Value lhs, Value rhs) {
- build(builder, result, getI1SameShape(lhs.getType()), predicate, lhs, rhs);
-}
-
void ICmpOp::print(OpAsmPrinter &p) {
p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
<< ", " << getOperand(1);
@@ -1372,20 +1362,6 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-//===----------------------------------------------------------------------===//
-// ExtractElementOp
-//===----------------------------------------------------------------------===//
-
-/// Expects vector to be an LLVM vector type and position to be an integer type.
-void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
- Value vector, Value position,
- ArrayRef<NamedAttribute> attrs) {
- auto vectorType = vector.getType();
- auto llvmType = LLVM::getVectorElementType(vectorType);
- build(b, result, llvmType, vector, position);
- result.addAttributes(attrs);
-}
-
//===----------------------------------------------------------------------===//
// ExtractValueOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 09a907e252110..7b9f9e41f5e85 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1005,28 +1005,6 @@ static LogicalResult verifyShiftOp(Operation *op) {
return success();
}
-static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
- Value lhs, Value rhs) {
- assert(lhs.getType() == rhs.getType());
-
- Type boolType = builder.getI1Type();
- if (auto vecType = lhs.getType().dyn_cast<VectorType>())
- boolType = VectorType::get(vecType.getShape(), boolType);
- state.addTypes(boolType);
-
- state.addOperands({lhs, rhs});
-}
-
-static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
- Value value) {
- Type boolType = builder.getI1Type();
- if (auto vecType = value.getType().dyn_cast<VectorType>())
- boolType = VectorType::get(vecType.getShape(), boolType);
- state.addTypes(boolType);
-
- state.addOperands(value);
-}
-
//===----------------------------------------------------------------------===//
// spirv.AccessChainOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d5975876cfb2e..cc56ce6271e88 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1015,12 +1015,6 @@ void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
result.addTypes(source.getType().cast<VectorType>().getElementType());
}
-void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
- Value source, Value position) {
- result.addOperands({source, position});
- result.addTypes(source.getType().cast<VectorType>().getElementType());
-}
-
LogicalResult vector::ExtractElementOp::verify() {
VectorType vectorType = getVectorType();
if (vectorType.getRank() == 0) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 7567dc5944228..092c457d3bed9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -432,7 +432,8 @@ static void createFullPartialVectorTransferWrite(RewriterBase &b,
Value load = b.create<memref::LoadOp>(
loc,
b.create<vector::TypeCastOp>(
- loc, MemRefType::get({}, xferOp.getVector().getType()), alloc));
+ loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
+ ValueRange());
mapping.map(xferOp.getVector(), load);
b.clone(*xferOp.getOperation(), mapping);
b.create<scf::YieldOp>(loc, ValueRange{});
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 150c385fd2d00..dd528406fe9a2 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -25,6 +25,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
+#include <list>
#define DEBUG_TYPE "mlir-tblgen-operator"
@@ -344,11 +345,6 @@ auto Operator::getOperands() const -> const_value_range {
auto Operator::getArg(int index) const -> Argument { return arguments[index]; }
-// Mapping from result index to combined argument and result index. Arguments
-// are indexed to match getArg index, while the result indexes are mapped to
-// avoid overlap.
-static int resultIndex(int i) { return -1 - i; }
-
bool Operator::isVariadic() const {
return any_of(llvm::concat<const NamedTypeConstraint>(operands, results),
[](const NamedTypeConstraint &op) { return op.isVariadic(); });
@@ -384,46 +380,47 @@ void Operator::populateTypeInferenceInfo(
if (operandI == arguments.end())
return;
- // Map each of the result types to the anchor operation.
+ // All result types are inferred from the operand type.
int operandIdx = operandI - arguments.begin();
- resultTypeMapping.resize(getNumResults());
for (int i = 0; i < getNumResults(); ++i)
- resultTypeMapping[i].emplace_back(operandIdx);
+ resultTypeMapping.emplace_back(operandIdx, "$_self");
allResultsHaveKnownTypes = true;
traits.push_back(Trait::create(inferTrait->getDefInit()));
return;
}
- // We create equivalence classes of argument/result types where arguments
- // and results are mapped into the same index space and indices corresponding
- // to the same type are in the same equivalence class.
- llvm::EquivalenceClasses<int> ecs;
- resultTypeMapping.resize(getNumResults());
- // Captures the argument whose type matches a given result type. Preference
- // towards capturing operands first before attributes.
- auto captureMapping = [&](int i) {
- bool found = false;
- ecs.insert(resultIndex(i));
- auto mi = ecs.findLeader(resultIndex(i));
- for (auto me = ecs.member_end(); mi != me; ++mi) {
- if (*mi < 0) {
- auto tc = getResultTypeConstraint(i);
- if (tc.getBuilderCall()) {
- resultTypeMapping[i].emplace_back(tc);
- found = true;
- }
- continue;
- }
+ /// This struct represents a node in this operation's result type inference
+ /// graph. Each node has a list of incoming type inference edges `sources`.
+ /// Each edge represents a "source" from which the result type can be
+ /// inferred, either an operand (leaf) or another result (node). When a node
+ /// is known to have a fully-inferred type, `inferred` is set to true.
+ struct ResultTypeInference {
+ /// The list of incoming type inference edges.
+ SmallVector<InferredResultType> sources;
+ /// This flag is set to true when the result type is known to be inferrable.
+ bool inferred = false;
+ };
- resultTypeMapping[i].emplace_back(*mi);
- found = true;
+ // This vector represents the type inference graph, with one node for each
+ // operation result. The nth element is the node for the nth result.
+ SmallVector<ResultTypeInference> inference(getNumResults(), {});
+
+ // For all results whose types are buildable, initialize their type inference
+ // nodes with an edge to themselves. Mark those nodes as fully-inferred.
+ for (auto &[idx, infer] : llvm::enumerate(inference)) {
+ if (getResult(idx).constraint.getBuilderCall()) {
+ infer.sources.emplace_back(InferredResultType::mapResultIndex(idx),
+ "$_self");
+ infer.inferred = true;
}
- return found;
- };
+ }
+ // Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the
+ // result type inference graph.
for (const Trait &trait : traits) {
const llvm::Record &def = trait.getDef();
+
// If the infer type op interface was manually added, then treat it as
// intention that the op needs special handling.
// TODO: Reconsider whether to always generate, this is more conservative
@@ -435,24 +432,106 @@ void Operator::populateTypeInferenceInfo(
if (&traitDef->getDef() == inferTrait)
return;
+ // The `TypesMatchWith` trait represents a 1 -> 1 type inference edge with a
+ // type transformer.
+ if (def.isSubClassOf("TypesMatchWith")) {
+ int target = argumentsAndResultsIndex.lookup(def.getValueAsString("rhs"));
+ // Ignore operand type inference.
+ if (InferredResultType::isArgIndex(target))
+ continue;
+ int resultIndex = InferredResultType::unmapResultIndex(target);
+ ResultTypeInference &infer = inference[resultIndex];
+ // If the type of the result has already been inferred, do nothing.
+ if (infer.inferred)
+ continue;
+ int sourceIndex =
+ argumentsAndResultsIndex.lookup(def.getValueAsString("lhs"));
+ infer.sources.emplace_back(sourceIndex,
+ def.getValueAsString("transformer").str());
+ // Locally propagate inferredness.
+ infer.inferred =
+ InferredResultType::isArgIndex(sourceIndex) ||
+ inference[InferredResultType::unmapResultIndex(sourceIndex)].inferred;
+ continue;
+ }
+
if (!def.isSubClassOf("AllTypesMatch"))
continue;
auto values = def.getValueAsListOfStrings("values");
- auto root = argumentsAndResultsIndex.lookup(values.front());
- for (StringRef str : values)
- ecs.unionSets(argumentsAndResultsIndex.lookup(str), root);
+ // The `AllTypesMatch` trait represents an N <-> N fanin and fanout. That
+ // is, every result type has an edge from every other type. However, if any
+ // one of the values refers to an operand or a result with a fully-inferred
+ // type, we can infer all other types from that value. Try to find a
+ // fully-inferred type in the list.
+ std::optional<int> fullyInferredIndex;
+ SmallVector<int> resultIndices;
+ for (StringRef name : values) {
+ int index = argumentsAndResultsIndex.lookup(name);
+ if (InferredResultType::isResultIndex(index))
+ resultIndices.push_back(InferredResultType::unmapResultIndex(index));
+ if (InferredResultType::isArgIndex(index) ||
+ inference[InferredResultType::unmapResultIndex(index)].inferred)
+ fullyInferredIndex = index;
+ }
+ if (fullyInferredIndex) {
+ // Make the fully-inferred type the only source for all results that
+ // aren't already inferred -- a 1 -> N fanout.
+ for (int resultIndex : resultIndices) {
+ ResultTypeInference &infer = inference[resultIndex];
+ if (!infer.inferred) {
+ infer.sources.assign(1, {*fullyInferredIndex, "$_self"});
+ infer.inferred = true;
+ }
+ }
+ } else {
+ // Add an edge between every result and every other type; N <-> N.
+ for (int resultIndex : resultIndices) {
+ for (int otherResultIndex : resultIndices) {
+ if (resultIndex == otherResultIndex)
+ continue;
+ inference[resultIndex].sources.emplace_back(otherResultIndex,
+ "$_self");
+ }
+ }
+ }
}
- // Verifies that all output types have a corresponding known input type
- // and chooses matching operand or attribute (in that order) that
- // matches it.
- allResultsHaveKnownTypes =
- all_of(llvm::seq<int>(0, getNumResults()), captureMapping);
+ // Propagate inferredness until a fixed point.
+ std::list<ResultTypeInference *> worklist;
+ for (ResultTypeInference &infer : inference)
+ if (!infer.inferred)
+ worklist.push_back(&infer);
+ bool changed;
+ do {
+ changed = false;
+ // This is `llvm::make_early_inc_range` but keeps the iterator for erasing.
+ for (auto earlyIncIt = worklist.begin(), cur = earlyIncIt;
+ cur = earlyIncIt++, cur != worklist.end();) {
+ ResultTypeInference &infer = **cur;
+ for (auto &[idx, source] : llvm::enumerate(infer.sources)) {
+ assert(InferredResultType::isResultIndex(source.getIndex()));
+ if (inference[InferredResultType::unmapResultIndex(source.getIndex())]
+ .inferred) {
+ changed = true;
+ infer.inferred = true;
+ // Make this the only source for the result. This breaks any cycles.
+ infer.sources.assign(1, source);
+ worklist.erase(cur);
+ break;
+ }
+ }
+ }
+ } while (changed);
+
+ allResultsHaveKnownTypes = worklist.empty();
// If the types could be computed, then add type inference trait.
- if (allResultsHaveKnownTypes)
+ if (allResultsHaveKnownTypes) {
traits.push_back(Trait::create(inferTrait->getDefInit()));
+ for (const ResultTypeInference &infer : inference)
+ resultTypeMapping.push_back(infer.sources.front());
+ }
}
void Operator::populateOpStructure() {
@@ -562,7 +641,7 @@ void Operator::populateOpStructure() {
resultDef = resultDef->getValueAsDef("constraint");
results.push_back({name, TypeConstraint(resultDef)});
if (!name.empty())
- argumentsAndResultsIndex[name] = resultIndex(i);
+ argumentsAndResultsIndex[name] = InferredResultType::mapResultIndex(i);
// We currently only support VariadicOfVariadic operands.
if (results.back().constraint.isVariadicOfVariadic()) {
@@ -683,7 +762,7 @@ void Operator::populateOpStructure() {
LLVM_DEBUG(print(llvm::dbgs()));
}
-auto Operator::getSameTypeAsResult(int index) const -> ArrayRef<ArgOrType> {
+const InferredResultType &Operator::getInferredResultType(int index) const {
assert(allResultTypesKnown());
return resultTypeMapping[index];
}
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index 33bbfec821083..06178f8489c00 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -155,6 +155,22 @@ def OpL3 : NS_Op<"op_with_all_types_constraint",
// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").cast<::mlir::TypedAttr>().getType();
// CHECK: inferredReturnTypes[0] = odsInferredType0;
+def OpL4 : NS_Op<"two_inference_edges", [
+ TypesMatchWith<"", "a", "b", "infer0($_self)">,
+ TypesMatchWith<"", "b", "c", "infer1($_self)">,
+ TypesMatchWith<"", "input", "a", "fromInput($_self)">]> {
+ let arguments = (ins I32:$input);
+ let results = (outs AnyType:$a, AnyType:$b, AnyType:$c);
+}
+
+// CHECK-LABEL: LogicalResult OpL4::inferReturnTypes
+// CHECK: odsInferredType0 = fromInput(operands[0].getType())
+// CHECK: odsInferredType1 = infer0(odsInferredType0)
+// CHECK: odsInferredType2 = infer1(odsInferredType1)
+// CHECK: inferredReturnTypes[0] = odsInferredType0
+// CHECK: inferredReturnTypes[1] = odsInferredType1
+// CHECK: inferredReturnTypes[2] = odsInferredType2
+
def OpM : NS_Op<"mix_
diff _size_variadic_and_normal_results_op", [AttrSizedResultSegments]> {
let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Optional<AnyTensor>:$output3);
}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 2483378f691bb..3e7166e4ccd8c 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -22,6 +22,7 @@
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/SideEffects.h"
#include "mlir/TableGen/Trait.h"
+#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
@@ -2518,67 +2519,57 @@ void OpEmitter::genTypeInterfaceMethods() {
FmtContext fctx;
fctx.withBuilder("odsBuilder");
+ fctx.addSubst("_ctxt", "context");
body << " ::mlir::Builder odsBuilder(context);\n";
- // Preprocess the result types and build all of the types used during
- // inferrence. This limits the amount of duplicated work when a type is used
- // to infer multiple others.
- llvm::DenseMap<Constraint, int> constraintsTypes;
- llvm::DenseMap<int, int> argumentsTypes;
+ // Process the type inference graph in topological order, starting from types
+ // that are always fully-inferred: operands and results with constructible
+ // types. The type inference graph here will always be a DAG, so this gives
+ // us the correct order for generating the types. -1 is a placeholder to
+ // indicate the type for a result has not been generated.
+ SmallVector<int> constructedIndices(op.getNumResults(), -1);
int inferredTypeIdx = 0;
- for (int i = 0, e = op.getNumResults(); i != e; ++i) {
- auto type = op.getSameTypeAsResult(i).front();
-
- // If the type isn't an argument, it refers to a buildable type.
- if (!type.isArg()) {
- auto it = constraintsTypes.try_emplace(type.getType(), inferredTypeIdx);
- if (!it.second)
+ for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) {
+ for (int i = 0, e = op.getNumResults(); i != e; ++i) {
+ if (constructedIndices[i] >= 0)
continue;
-
- // If we haven't seen this constraint, generate a variable for it.
- body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
- << tgfmt(*type.getType().getBuilderCall(), &fctx) << ";\n";
- continue;
- }
-
- // Otherwise, this is an argument.
- int argIndex = type.getArg();
- auto it = argumentsTypes.try_emplace(argIndex, inferredTypeIdx);
- if (!it.second)
- continue;
- body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = ";
-
- // If this is an operand, just index into operand list to access the type.
- auto arg = op.getArgToOperandOrAttribute(argIndex);
- if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
- body << "operands[" << arg.operandOrAttributeIndex() << "].getType()";
-
- // If this is an attribute, index into the attribute dictionary.
- } else {
- auto *attr =
- op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
- body << "attributes.get(\"" << attr->name
- << "\").cast<::mlir::TypedAttr>().getType()";
+ const InferredResultType &infer = op.getInferredResultType(i);
+ std::string typeStr;
+ body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = ";
+ if (infer.isArg()) {
+ // If this is an operand, just index into operand list to access the
+ // type.
+ auto arg = op.getArgToOperandOrAttribute(infer.getIndex());
+ if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
+ typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
+ "].getType()")
+ .str();
+
+ // If this is an attribute, index into the attribute dictionary.
+ } else {
+ auto *attr =
+ op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
+ typeStr = ("attributes.get(\"" + attr->name +
+ "\").cast<::mlir::TypedAttr>().getType()")
+ .str();
+ }
+ } else if (std::optional<StringRef> builder =
+ op.getResult(infer.getResultIndex())
+ .constraint.getBuilderCall()) {
+ typeStr = tgfmt(*builder, &fctx).str();
+ } else if (int index = constructedIndices[infer.getResultIndex()];
+ index >= 0) {
+ typeStr = ("odsInferredType" + Twine(index)).str();
+ } else {
+ continue;
+ }
+ body << tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n";
+ constructedIndices[i] = inferredTypeIdx - 1;
}
- body << ";\n";
}
-
- // Perform a second pass that handles assigning the inferred types to the
- // results.
- for (int i = 0, e = op.getNumResults(); i != e; ++i) {
- auto types = op.getSameTypeAsResult(i);
-
- // Append the inferred type.
- auto type = types.front();
- body << " inferredReturnTypes[" << i << "] = odsInferredType"
- << (type.isArg() ? argumentsTypes[type.getArg()]
- : constraintsTypes[type.getType()])
+ for (auto [i, index] : llvm::enumerate(constructedIndices))
+ body << " inferredReturnTypes[" << i << "] = odsInferredType" << index
<< ";\n";
-
- if (types.size() == 1)
- continue;
- // TODO: We could verify equality here, but skipping that for verification.
- }
body << " return ::mlir::success();";
}
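A side note on the InferredResultType helpers added to Operator.h above: arguments and results are folded into one index space by encoding result indices as negative numbers. A minimal standalone sketch of that arithmetic (same formulas as the static helpers in the diff; plain C++, nothing MLIR-specific assumed):

#include <cassert>

// Mirrors InferredResultType's static helpers: arguments keep their
// non-negative indices, results are folded into the negative range so both
// kinds can share a single index space without colliding.
static int mapResultIndex(int i) { return -1 - i; }   // result 0 -> -1, 2 -> -3
static int unmapResultIndex(int i) { return -i - 1; } // -1 -> 0,  -3 -> 2
static bool isResultIndex(int i) { return i < 0; }
static bool isArgIndex(int i) { return i >= 0; }

int main() {
  assert(mapResultIndex(0) == -1 && unmapResultIndex(mapResultIndex(0)) == 0);
  assert(mapResultIndex(2) == -3 && unmapResultIndex(-3) == 2);
  assert(isArgIndex(5) && isResultIndex(mapResultIndex(5)));
  return 0;
}

This is what lets argumentsAndResultsIndex map both argument names and result names into one lookup table, with isArgIndex/isResultIndex distinguishing the two kinds when an entry is read back.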