[Mlir-commits] [mlir] 688c51a - [MLIR] Add mlir::TypedValue
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Aug 28 20:26:50 PDT 2022
Author: Tyker
Date: 2022-08-28T20:26:37-07:00
New Revision: 688c51a5acc53b456014e53663051476d825e896
URL: https://github.com/llvm/llvm-project/commit/688c51a5acc53b456014e53663051476d825e896
DIFF: https://github.com/llvm/llvm-project/commit/688c51a5acc53b456014e53663051476d825e896.diff
LOG: [MLIR] Add mlir::TypedValue
mlir::TypedValue is a wrapper class for mlir::Values with a known type.
getType will return the known type, and all assignments will be checked.
Also the tablegen Operation generator was adapted to use mlir::TypedValue
when appropriate
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
mlir/include/mlir/IR/Value.h
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/lib/Dialect/Test/TestDialect.h
mlir/test/mlir-tblgen/op-attribute.td
mlir/test/mlir-tblgen/op-decl-and-defs.td
mlir/test/mlir-tblgen/op-operand.td
mlir/test/mlir-tblgen/op-result.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
index 664797bb8b84e..bc51845b96064 100644
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 2f0524c4da7df..c9ff4f0145938 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -419,6 +419,27 @@ inline unsigned OpResultImpl::getResultNumber() const {
return cast<InlineOpResult>(this)->getResultNumber();
}
+/// TypedValue is a Value with a statically know type.
+/// TypedValue can be null/empty
+template <typename Ty>
+struct TypedValue : Value {
+ /// Return the known Type
+ Ty getType() { return Value::getType().template cast<Ty>(); }
+ void setType(mlir::Type ty) {
+ assert(ty.template isa<Ty>());
+ Value::setType(ty);
+ }
+
+ TypedValue(Value val) : Value(val) {
+ assert(!val || val.getType().template isa<Ty>());
+ }
+ TypedValue &operator=(const Value &other) {
+ assert(!other || other.getType().template isa<Ty>());
+ Value::operator=(other);
+ return *this;
+ }
+};
+
} // namespace detail
/// This is a value defined by a result of an operation.
@@ -459,6 +480,12 @@ inline ::llvm::hash_code hash_value(Value arg) {
return ::llvm::hash_value(arg.getImpl());
}
+template <typename Ty, typename Value = mlir::Value>
+/// If Ty is mlir::Type this will select `Value` instead of having a wrapper
+/// around it. This helps resolve ambiguous conversion issues.
+using TypedValue = std::conditional_t<std::is_same_v<Ty, mlir::Type>,
+ mlir::Value, detail::TypedValue<Ty>>;
+
} // namespace mlir
namespace llvm {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 8756603acf3e6..e4448ff1f3100 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -784,7 +784,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
launchOp.getKernelName().getValue(), loc, rewriter);
auto function = moduleGetFunctionCallBuilder.create(
loc, rewriter, {module.getResult(), kernelName});
- auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
+ Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
Value stream =
adaptor.asyncDependencies().empty()
? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 7569476cb7cd1..501d95b3c791d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -428,7 +428,7 @@ LogicalResult mlir::loopUnrollByFactor(
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
OpBuilder boundsBuilder(forOp);
auto loc = forOp.getLoc();
- auto step = forOp.getStep();
+ Value step = forOp.getStep();
Value upperBoundUnrolled;
Value stepUnrolled;
bool generateEpilogueLoop = true;
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2b8cae83fd414..2227064f6791c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -778,7 +778,7 @@ struct CanonicalizeCastExtentTensorOperandsPattern
PatternRewriter &rewriter) const override {
// Canonicalize operands.
bool anyChange = false;
- auto canonicalizeOperand = [&](Value operand) {
+ auto canonicalizeOperand = [&](Value operand) -> Value {
if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
// Only eliminate the cast if it holds no shape information.
bool isInformationLoosingCast =
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 0e583d38aa1a5..ceb9dc6f4c933 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -14,6 +14,7 @@
#ifndef MLIR_TESTDIALECT_H
#define MLIR_TESTDIALECT_H
+#include "TestTypes.h"
#include "TestAttributes.h"
#include "TestInterfaces.h"
#include "mlir/Dialect/DLTI/DLTI.h"
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 3c85d71460dc5..3e632d7502179 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -471,8 +471,8 @@ def OpWithDefaultAndSuccessor : NS_Op<"default_with_succ", []> {
// DECL: static void build({{.*}}, bool dv_bool_attr, ::mlir::BlockRange succ)
// DEF-LABEL: MixOperandsAndAttrs definitions
-// DEF-DAG: ::mlir::Value MixOperandsAndAttrs::operand()
-// DEF-DAG: ::mlir::Value MixOperandsAndAttrs::otherArg()
+// DEF-DAG: ::mlir::TypedValue<::mlir::FloatType> MixOperandsAndAttrs::operand()
+// DEF-DAG: ::mlir::TypedValue<::mlir::FloatType> MixOperandsAndAttrs::otherArg()
// DEF-DAG: void MixOperandsAndAttrs::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::FloatAttr attr, ::mlir::Value operand, ::mlir::FloatAttr otherAttr, ::mlir::Value otherArg)
// DEF-DAG: ::llvm::APFloat MixOperandsAndAttrs::attr()
// DEF-DAG: ::llvm::APFloat MixOperandsAndAttrs::otherAttr()
diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index cab95c3294673..76d76c974a733 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -78,12 +78,12 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: return ::llvm::StringLiteral("test.a_op");
// CHECK: }
// CHECK: ::mlir::Operation::operand_range getODSOperands(unsigned index);
-// CHECK: ::mlir::Value getA();
+// CHECK: ::mlir::TypedValue<::mlir::IntegerType> getA();
// CHECK: ::mlir::Operation::operand_range getB();
// CHECK: ::mlir::MutableOperandRange getAMutable();
// CHECK: ::mlir::MutableOperandRange getBMutable();
// CHECK: ::mlir::Operation::result_range getODSResults(unsigned index);
-// CHECK: ::mlir::Value getR();
+// CHECK: ::mlir::TypedValue<::mlir::IntegerType> getR();
// CHECK: ::mlir::Region &getSomeRegion();
// CHECK: ::mlir::MutableArrayRef<::mlir::Region> getSomeRegions();
// CHECK: ::mlir::IntegerAttr getAttr1Attr()
@@ -169,7 +169,7 @@ def NS_EOp : NS_Op<"op_with_optionals", []> {
// CHECK-LABEL: NS::EOp declarations
// CHECK: ::mlir::Value getA();
// CHECK: ::mlir::MutableOperandRange getAMutable();
-// CHECK: ::mlir::Value getB();
+// CHECK: ::mlir::TypedValue<::mlir::FloatType> getB();
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::Type b, /*optional*/::mlir::Value a)
// Check that all types match constraint results in generating builder.
diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td
index 7e2c2c55e0a4b..c4f904dc0a14a 100644
--- a/mlir/test/mlir-tblgen/op-operand.td
+++ b/mlir/test/mlir-tblgen/op-operand.td
@@ -54,7 +54,7 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
// CHECK-LABEL: ::mlir::Operation::operand_range OpD::input1
// CHECK-NEXT: return getODSOperands(0);
-// CHECK-LABEL: ::mlir::Value OpD::input2
+// CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpD::input2
// CHECK-NEXT: return *getODSOperands(1).begin();
// CHECK-LABEL: OpD::build
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index cf7fb6c30870d..5c356e0137fd7 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -102,7 +102,7 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
// CHECK-LABEL: ::mlir::Operation::result_range OpI::output1
// CHECK-NEXT: return getODSResults(0);
-// CHECK-LABEL: ::mlir::Value OpI::output2
+// CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpI::output2
// CHECK-NEXT: return *getODSResults(1).begin();
// CHECK-LABEL: OpI::build
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 9d5c391224c0d..70c22e17f27ee 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1152,6 +1152,25 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
}
}
+static std::string generateTypeForGetter(bool isAdaptor,
+ const NamedTypeConstraint &value) {
+ std::string str = "::mlir::Value";
+ /// If the CPPClassName is not a fully qualified type. Uses of types
+ /// across Dialect fail because they are not in the correct namespace. So we
+ /// dont generate TypedValue unless the type is fully qualified.
+ /// getCPPClassName doesn't return the fully qualified path for
+ /// `mlir::pdl::OperationType` see
+ /// https://github.com/llvm/llvm-project/issues/57279.
+ /// Adaptor will have values that are not from the type of their operation and
+ /// this is expected, so we dont generate TypedValue for Adaptor
+ if (!isAdaptor && value.constraint.getCPPClassName() != "::mlir::Type" &&
+ StringRef(value.constraint.getCPPClassName()).startswith("::"))
+ str = llvm::formatv("::mlir::TypedValue<{0}>",
+ value.constraint.getCPPClassName())
+ .str();
+ return str;
+}
+
// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`. Uses `rangeType` as the return type of getters that
// return a range of operands (individual operands are `Value ` and each
@@ -1216,7 +1235,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
continue;
for (StringRef name : op.getGetterNames(operand.name)) {
if (operand.isOptional()) {
- m = opClass.addMethod("::mlir::Value", name);
+ m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " auto operands = getODSOperands(" << i << ");\n"
<< " return operands.empty() ? ::mlir::Value() : "
@@ -1242,7 +1261,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSOperands(" << i << ");";
} else {
- m = opClass.addMethod("::mlir::Value", name);
+ m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return *getODSOperands(" << i << ").begin();";
}
@@ -1365,7 +1384,8 @@ void OpEmitter::genNamedResultGetters() {
continue;
for (StringRef name : op.getGetterNames(result.name)) {
if (result.isOptional()) {
- m = opClass.addMethod("::mlir::Value", name);
+ m = opClass.addMethod(
+ generateTypeForGetter(/*isAdaptor=*/false, result), name);
ERROR_IF_PRUNED(m, name, op);
m->body()
<< " auto results = getODSResults(" << i << ");\n"
@@ -1375,7 +1395,8 @@ void OpEmitter::genNamedResultGetters() {
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSResults(" << i << ");";
} else {
- m = opClass.addMethod("::mlir::Value", name);
+ m = opClass.addMethod(
+ generateTypeForGetter(/*isAdaptor=*/false, result), name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return *getODSResults(" << i << ").begin();";
}
More information about the Mlir-commits
mailing list