[Mlir-commits] [mlir] 0ce25b1 - [mlir] Require explicit casts when using TypedValue
Rahul Kayaith
llvmlistbot at llvm.org
Wed Feb 1 18:55:01 PST 2023
Author: Rahul Kayaith
Date: 2023-02-01T21:54:53-05:00
New Revision: 0ce25b12357b24d06cf08cc02719c144d567d5db
URL: https://github.com/llvm/llvm-project/commit/0ce25b12357b24d06cf08cc02719c144d567d5db
DIFF: https://github.com/llvm/llvm-project/commit/0ce25b12357b24d06cf08cc02719c144d567d5db.diff
LOG: [mlir] Require explicit casts when using TypedValue
Currently, `TypedValue` can be constructed implicitly from `Value`, hiding
type errors that could be caught at compile time. For example, the following
compiles but crashes or asserts at runtime:
```
void foo(TypedValue<IntegerType>);
void bar(TypedValue<FloatType> v) {
foo(v);
}
```
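This change removes the implicit conversions from `Value` (the converting
constructor and assignment operator) and replaces them with explicit LLVM
casts, e.g. `llvm::cast<TypedValue<T>>(value)`, enabled by a new
`TypedValue::classof`.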
Depends on D142852
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D142855
Added:
Modified:
mlir/docs/Tutorials/Toy/Ch-4.md
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Value.h
mlir/include/mlir/TableGen/Class.h
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
mlir/test/mlir-tblgen/op-operand.td
mlir/test/mlir-tblgen/op-result.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/unittests/IR/IRMapping.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md
index 95cdca0c3fa32..df82141c77ba5 100644
--- a/mlir/docs/Tutorials/Toy/Ch-4.md
+++ b/mlir/docs/Tutorials/Toy/Ch-4.md
@@ -375,7 +375,7 @@ inferred as the shape of the inputs.
```c++
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
-void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
```
At this point, each of the necessary Toy operations provide a mechanism by which
diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index c87e107f6f415..17a42d69c8f4c 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -237,7 +237,7 @@ void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
/// Infer the output shape of the AddOp, this is required by the shape inference
/// interface.
-void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void AddOp::inferShapes() { getResult().setType(getLhs().getType()); }
//===----------------------------------------------------------------------===//
// CastOp
@@ -245,7 +245,7 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
-void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
+void CastOp::inferShapes() { getResult().setType(getInput().getType()); }
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
@@ -349,7 +349,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
-void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
//===----------------------------------------------------------------------===//
// ReturnOp
diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index 04ae3149281f9..77ceb636e17f2 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -237,7 +237,7 @@ void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
/// Infer the output shape of the AddOp, this is required by the shape inference
/// interface.
-void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void AddOp::inferShapes() { getResult().setType(getLhs().getType()); }
//===----------------------------------------------------------------------===//
// CastOp
@@ -245,7 +245,7 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
-void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
+void CastOp::inferShapes() { getResult().setType(getInput().getType()); }
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
@@ -349,7 +349,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
-void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
//===----------------------------------------------------------------------===//
// ReturnOp
diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index 04ae3149281f9..77ceb636e17f2 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -237,7 +237,7 @@ void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
/// Infer the output shape of the AddOp, this is required by the shape inference
/// interface.
-void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void AddOp::inferShapes() { getResult().setType(getLhs().getType()); }
//===----------------------------------------------------------------------===//
// CastOp
@@ -245,7 +245,7 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
-void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
+void CastOp::inferShapes() { getResult().setType(getInput().getType()); }
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
@@ -349,7 +349,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
-void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
//===----------------------------------------------------------------------===//
// ReturnOp
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 3382cbcd3074a..188b94fc2dfeb 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -243,7 +243,9 @@ mlir::LogicalResult StructConstantOp::verify() {
/// Infer the output shape of the ConstantOp, this is required by the shape
/// inference interface.
-void ConstantOp::inferShapes() { getResult().setType(getValue().getType()); }
+void ConstantOp::inferShapes() {
+ getResult().setType(cast<TensorType>(getValue().getType()));
+}
//===----------------------------------------------------------------------===//
// AddOp
@@ -264,7 +266,7 @@ void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
/// Infer the output shape of the AddOp, this is required by the shape inference
/// interface.
-void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void AddOp::inferShapes() { getResult().setType(getLhs().getType()); }
//===----------------------------------------------------------------------===//
// CastOp
@@ -272,7 +274,7 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
-void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
+void CastOp::inferShapes() { getResult().setType(getInput().getType()); }
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
@@ -376,7 +378,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
-void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
//===----------------------------------------------------------------------===//
// ReturnOp
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 60c961f91cfac..4c0a849ba908c 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -632,7 +632,7 @@ class OneTypedResult {
: public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
public:
TypedValue<ResultType> getResult() {
- return this->getOperation()->getResult(0);
+ return cast<TypedValue<ResultType>>(this->getOperation()->getResult(0));
}
/// If the operation returns a single value, then the Op can be implicitly
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index c84ae97c99b9b..e95bfcf0252d0 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -427,21 +427,13 @@ inline unsigned OpResultImpl::getResultNumber() const {
/// TypedValue can be null/empty
template <typename Ty>
struct TypedValue : Value {
+ using Value::Value;
+
+ static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); }
+
/// Return the known Type
Ty getType() { return Value::getType().template cast<Ty>(); }
- void setType(mlir::Type ty) {
- assert(ty.template isa<Ty>());
- Value::setType(ty);
- }
-
- TypedValue(Value val) : Value(val) {
- assert(!val || val.getType().template isa<Ty>());
- }
- TypedValue &operator=(const Value &other) {
- assert(!other || other.getType().template isa<Ty>());
- Value::operator=(other);
- return *this;
- }
+ void setType(Ty ty) { Value::setType(ty); }
};
} // namespace detail
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 954ef5bec2d04..9c4efcd41cd6b 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -152,6 +152,9 @@ class MethodSignature {
/// Get the name of the method.
StringRef getName() const { return methodName; }
+ /// Get the return type of the method
+ StringRef getReturnType() const { return returnType; }
+
/// Get the number of parameters.
unsigned getNumParameters() const { return parameters.getNumParameters(); }
@@ -344,6 +347,9 @@ class Method : public ClassDeclarationBase<ClassDeclaration::Method> {
/// Returns the name of this method.
StringRef getName() const { return methodSignature.getName(); }
+ /// Returns the return type of this method
+ StringRef getReturnType() const { return methodSignature.getReturnType(); }
+
/// Returns if this method makes the `other` method redundant.
bool makesRedundant(const Method &other) const {
return methodSignature.makesRedundant(other.methodSignature);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index aa41c962826da..46f5292a9ad1a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1884,7 +1884,7 @@ class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
llvm::SmallVector<Value> indices;
for (unsigned int i = 0; i < inputTy.getRank(); i++) {
- auto index =
+ Value index =
rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
if (i == axis) {
auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 931a85b636f9d..d8070b34a761d 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1033,7 +1033,7 @@ struct UnrollTransferWriteConversion
auto vec = getDataVector(xferOp);
auto xferVecType = xferOp.getVectorType();
int64_t dimSize = xferVecType.getShape()[0];
- auto source = xferOp.getSource(); // memref or tensor to be written to.
+ Value source = xferOp.getSource(); // memref or tensor to be written to.
auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
// Generate fully unrolled loop of transfer ops.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 1e2a7b017b7ff..9cd2331c24d19 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1056,7 +1056,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
// %t = sparse_tensor.ConvertOp %tmp
RankedTensorType cooTp =
getUnorderedCOOFromTypeWithOrdering(dstTp, encDst.getDimOrdering());
- auto cooBuffer =
+ Value cooBuffer =
rewriter.create<AllocTensorOp>(loc, cooTp, dynSizesArray).getResult();
Value c0 = constantIndex(rewriter, loc, 0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index d5554ea8fbd26..98c75510b46e7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -173,7 +173,8 @@ class StorageLayout {
class SparseTensorSpecifier {
public:
- explicit SparseTensorSpecifier(Value specifier) : specifier(specifier) {}
+ explicit SparseTensorSpecifier(Value specifier)
+ : specifier(cast<TypedValue<StorageSpecifierType>>(specifier)) {}
// Undef value for dimension sizes, all zero value for memory sizes.
static Value getInitValue(OpBuilder &builder, Location loc,
diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td
index 7a76f98ead581..68a9def83c2e0 100644
--- a/mlir/test/mlir-tblgen/op-operand.td
+++ b/mlir/test/mlir-tblgen/op-operand.td
@@ -43,7 +43,7 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
// CHECK-NEXT: return getODSOperands(0);
// CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpD::getInput2
-// CHECK-NEXT: return *getODSOperands(1).begin();
+// CHECK-NEXT: return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin());
// CHECK-LABEL: OpD::build
// CHECK-NEXT: odsState.addOperands(input1);
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index 06178f8489c00..d49bffa9cb441 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -100,7 +100,7 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
// CHECK-NEXT: return getODSResults(0);
// CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpI::getOutput2
-// CHECK-NEXT: return *getODSResults(1).begin();
+// CHECK-NEXT: return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSResults(1).begin());
// CHECK-LABEL: OpI::build
// CHECK-NEXT: odsState.addTypes(output1);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 7ed29f91d3d64..bc3e2599e72d7 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1337,10 +1337,12 @@ generateNamedOperandGetters(const Operator &op, Class &opClass,
: generateTypeForGetter(operand),
name);
ERROR_IF_PRUNED(m, name, op);
- m->body().indent() << formatv(
- "auto operands = getODSOperands({0});\n"
- "return operands.empty() ? {1}{{} : *operands.begin();",
- i, rangeElementType);
+ m->body().indent() << formatv("auto operands = getODSOperands({0});\n"
+ "return operands.empty() ? {1}{{} : ",
+ i, m->getReturnType());
+ if (!isGenericAdaptorBase)
+ m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType());
+ m->body() << "(*operands.begin());";
} else if (operand.isVariadicOfVariadic()) {
std::string segmentAttr = op.getGetterName(
operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
@@ -1366,7 +1368,10 @@ generateNamedOperandGetters(const Operator &op, Class &opClass,
: generateTypeForGetter(operand),
name);
ERROR_IF_PRUNED(m, name, op);
- m->body() << " return *getODSOperands(" << i << ").begin();";
+ m->body().indent() << "return ";
+ if (!isGenericAdaptorBase)
+ m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType());
+ m->body() << llvm::formatv("(*getODSOperands({0}).begin());", i);
}
}
}
@@ -1489,9 +1494,11 @@ void OpEmitter::genNamedResultGetters() {
if (result.isOptional()) {
m = opClass.addMethod(generateTypeForGetter(result), name);
ERROR_IF_PRUNED(m, name, op);
- m->body()
- << " auto results = getODSResults(" << i << ");\n"
- << " return results.empty() ? ::mlir::Value() : *results.begin();";
+ m->body() << " auto results = getODSResults(" << i << ");\n"
+ << llvm::formatv(" return results.empty()"
+ " ? {0}()"
+ " : ::llvm::cast<{0}>(*results.begin());",
+ m->getReturnType());
} else if (result.isVariadic()) {
m = opClass.addMethod("::mlir::Operation::result_range", name);
ERROR_IF_PRUNED(m, name, op);
@@ -1499,7 +1506,9 @@ void OpEmitter::genNamedResultGetters() {
} else {
m = opClass.addMethod(generateTypeForGetter(result), name);
ERROR_IF_PRUNED(m, name, op);
- m->body() << " return *getODSResults(" << i << ").begin();";
+ m->body() << llvm::formatv(
+ " return ::llvm::cast<{0}>(*getODSResults({1}).begin());",
+ m->getReturnType(), i);
}
}
}
diff --git a/mlir/unittests/IR/IRMapping.cpp b/mlir/unittests/IR/IRMapping.cpp
index bf00e5eea2260..83627975006ee 100644
--- a/mlir/unittests/IR/IRMapping.cpp
+++ b/mlir/unittests/IR/IRMapping.cpp
@@ -32,7 +32,7 @@ TEST(IRMapping, TypedValue) {
IRMapping mapping;
mapping.map(i64Val, f64Val);
- TypedValue<IntegerType> typedI64Val = i64Val;
+ auto typedI64Val = cast<TypedValue<IntegerType>>(i64Val);
EXPECT_EQ(mapping.lookup(typedI64Val), f64Val);
}
More information about the Mlir-commits
mailing list