[Mlir-commits] [mlir] 0ce25b1 - [mlir] Require explicit casts when using TypedValue

Rahul Kayaith llvmlistbot at llvm.org
Wed Feb 1 18:55:01 PST 2023


Author: Rahul Kayaith
Date: 2023-02-01T21:54:53-05:00
New Revision: 0ce25b12357b24d06cf08cc02719c144d567d5db

URL: https://github.com/llvm/llvm-project/commit/0ce25b12357b24d06cf08cc02719c144d567d5db
DIFF: https://github.com/llvm/llvm-project/commit/0ce25b12357b24d06cf08cc02719c144d567d5db.diff

LOG: [mlir] Require explicit casts when using TypedValue

Currently `TypedValue` can be constructed directly from `Value`, hiding
errors that could otherwise be caught at compile time. For example, the
following compiles but crashes or asserts at runtime:
```
void foo(TypedValue<IntegerType>);
void bar(TypedValue<FloatType> v) {
  foo(v);
}
```

This change removes the constructors and replaces them with explicit
llvm casts.
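
With this change the `foo(v)` call above no longer compiles; going from a
`Value` (or a differently-typed `TypedValue`) to a `TypedValue<IntegerType>`
has to be spelled out. A minimal sketch of the new usage, assuming `v` is now
a plain `Value` (the asserting form mirrors the updated `IRMapping` test
below):
```
void foo(TypedValue<IntegerType>);

void bar(Value v) {
  // Asserting form: use when `v` is known to be integer-typed;
  // asserts at the point of conversion otherwise.
  foo(cast<TypedValue<IntegerType>>(v));

  // Checked form: yields a null TypedValue when the type does not match.
  if (auto iv = dyn_cast<TypedValue<IntegerType>>(v))
    foo(iv);
}
```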

Depends on D142852

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D142855

Added: 
    

Modified: 
    mlir/docs/Tutorials/Toy/Ch-4.md
    mlir/examples/toy/Ch4/mlir/Dialect.cpp
    mlir/examples/toy/Ch5/mlir/Dialect.cpp
    mlir/examples/toy/Ch6/mlir/Dialect.cpp
    mlir/examples/toy/Ch7/mlir/Dialect.cpp
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/Value.h
    mlir/include/mlir/TableGen/Class.h
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
    mlir/test/mlir-tblgen/op-operand.td
    mlir/test/mlir-tblgen/op-result.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/unittests/IR/IRMapping.cpp

Removed: 
    


################################################################################
diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md
index 95cdca0c3fa32..df82141c77ba5 100644
--- a/mlir/docs/Tutorials/Toy/Ch-4.md
+++ b/mlir/docs/Tutorials/Toy/Ch-4.md
@@ -375,7 +375,7 @@ inferred as the shape of the inputs.
 ```c++
 /// Infer the output shape of the MulOp, this is required by the shape inference
 /// interface.
-void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
 ```
 
 At this point, each of the necessary Toy operations provide a mechanism by which

diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index c87e107f6f415..17a42d69c8f4c 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -237,7 +237,7 @@ void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
 
 /// Infer the output shape of the AddOp, this is required by the shape inference
 /// interface.
-void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void AddOp::inferShapes() { getResult().setType(getLhs().getType()); }
 
 //===----------------------------------------------------------------------===//
 // CastOp
@@ -245,7 +245,7 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
 
 /// Infer the output shape of the CastOp, this is required by the shape
 /// inference interface.
-void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
+void CastOp::inferShapes() { getResult().setType(getInput().getType()); }
 
 /// Returns true if the given set of input and result types are compatible with
 /// this cast operation. This is required by the `CastOpInterface` to verify
@@ -349,7 +349,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
 
 /// Infer the output shape of the MulOp, this is required by the shape inference
 /// interface.
-void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
 
 //===----------------------------------------------------------------------===//
 // ReturnOp

diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index 04ae3149281f9..77ceb636e17f2 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -237,7 +237,7 @@ void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
 
 /// Infer the output shape of the AddOp, this is required by the shape inference
 /// interface.
-void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void AddOp::inferShapes() { getResult().setType(getLhs().getType()); }
 
 //===----------------------------------------------------------------------===//
 // CastOp
@@ -245,7 +245,7 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
 
 /// Infer the output shape of the CastOp, this is required by the shape
 /// inference interface.
-void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
+void CastOp::inferShapes() { getResult().setType(getInput().getType()); }
 
 /// Returns true if the given set of input and result types are compatible with
 /// this cast operation. This is required by the `CastOpInterface` to verify
@@ -349,7 +349,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
 
 /// Infer the output shape of the MulOp, this is required by the shape inference
 /// interface.
-void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
 
 //===----------------------------------------------------------------------===//
 // ReturnOp

diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index 04ae3149281f9..77ceb636e17f2 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -237,7 +237,7 @@ void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
 
 /// Infer the output shape of the AddOp, this is required by the shape inference
 /// interface.
-void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void AddOp::inferShapes() { getResult().setType(getLhs().getType()); }
 
 //===----------------------------------------------------------------------===//
 // CastOp
@@ -245,7 +245,7 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
 
 /// Infer the output shape of the CastOp, this is required by the shape
 /// inference interface.
-void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
+void CastOp::inferShapes() { getResult().setType(getInput().getType()); }
 
 /// Returns true if the given set of input and result types are compatible with
 /// this cast operation. This is required by the `CastOpInterface` to verify
@@ -349,7 +349,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
 
 /// Infer the output shape of the MulOp, this is required by the shape inference
 /// interface.
-void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
 
 //===----------------------------------------------------------------------===//
 // ReturnOp

diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 3382cbcd3074a..188b94fc2dfeb 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -243,7 +243,9 @@ mlir::LogicalResult StructConstantOp::verify() {
 
 /// Infer the output shape of the ConstantOp, this is required by the shape
 /// inference interface.
-void ConstantOp::inferShapes() { getResult().setType(getValue().getType()); }
+void ConstantOp::inferShapes() {
+  getResult().setType(cast<TensorType>(getValue().getType()));
+}
 
 //===----------------------------------------------------------------------===//
 // AddOp
@@ -264,7 +266,7 @@ void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
 
 /// Infer the output shape of the AddOp, this is required by the shape inference
 /// interface.
-void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void AddOp::inferShapes() { getResult().setType(getLhs().getType()); }
 
 //===----------------------------------------------------------------------===//
 // CastOp
@@ -272,7 +274,7 @@ void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
 
 /// Infer the output shape of the CastOp, this is required by the shape
 /// inference interface.
-void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
+void CastOp::inferShapes() { getResult().setType(getInput().getType()); }
 
 /// Returns true if the given set of input and result types are compatible with
 /// this cast operation. This is required by the `CastOpInterface` to verify
@@ -376,7 +378,7 @@ void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
 
 /// Infer the output shape of the MulOp, this is required by the shape inference
 /// interface.
-void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
+void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
 
 //===----------------------------------------------------------------------===//
 // ReturnOp

diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 60c961f91cfac..4c0a849ba908c 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -632,7 +632,7 @@ class OneTypedResult {
       : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
   public:
     TypedValue<ResultType> getResult() {
-      return this->getOperation()->getResult(0);
+      return cast<TypedValue<ResultType>>(this->getOperation()->getResult(0));
     }
 
     /// If the operation returns a single value, then the Op can be implicitly

diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index c84ae97c99b9b..e95bfcf0252d0 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -427,21 +427,13 @@ inline unsigned OpResultImpl::getResultNumber() const {
 /// TypedValue can be null/empty
 template <typename Ty>
 struct TypedValue : Value {
+  using Value::Value;
+
+  static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); }
+
   /// Return the known Type
   Ty getType() { return Value::getType().template cast<Ty>(); }
-  void setType(mlir::Type ty) {
-    assert(ty.template isa<Ty>());
-    Value::setType(ty);
-  }
-
-  TypedValue(Value val) : Value(val) {
-    assert(!val || val.getType().template isa<Ty>());
-  }
-  TypedValue &operator=(const Value &other) {
-    assert(!other || other.getType().template isa<Ty>());
-    Value::operator=(other);
-    return *this;
-  }
+  void setType(Ty ty) { Value::setType(ty); }
 };
 
 } // namespace detail
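
For context on the `Value.h` change above: the new `classof` hook is what lets
LLVM's casting utilities (`isa`, `cast`, `dyn_cast`) target `TypedValue`; a
`Value` matches `TypedValue<Ty>` exactly when its type is an instance of `Ty`.
A rough sketch of what this enables (identifiers are illustrative):
```
// Assuming `v` is a non-null mlir::Value obtained elsewhere:
if (isa<TypedValue<TensorType>>(v)) {       // classof: isa<TensorType>(v.getType())
  auto t = cast<TypedValue<TensorType>>(v); // safe, the isa check succeeded
  TensorType ty = t.getType();              // statically-typed accessor
  t.setType(ty);                            // setType now takes Ty, not mlir::Type,
                                            // so a mismatch is a compile error
}
```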

diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 954ef5bec2d04..9c4efcd41cd6b 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -152,6 +152,9 @@ class MethodSignature {
   /// Get the name of the method.
   StringRef getName() const { return methodName; }
 
+  /// Get the return type of the method
+  StringRef getReturnType() const { return returnType; }
+
   /// Get the number of parameters.
   unsigned getNumParameters() const { return parameters.getNumParameters(); }
 
@@ -344,6 +347,9 @@ class Method : public ClassDeclarationBase<ClassDeclaration::Method> {
   /// Returns the name of this method.
   StringRef getName() const { return methodSignature.getName(); }
 
+  /// Returns the return type of this method
+  StringRef getReturnType() const { return methodSignature.getReturnType(); }
+
   /// Returns if this method makes the `other` method redundant.
   bool makesRedundant(const Method &other) const {
     return methodSignature.makesRedundant(other.methodSignature);

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index aa41c962826da..46f5292a9ad1a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1884,7 +1884,7 @@ class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
           llvm::SmallVector<Value> indices;
           for (unsigned int i = 0; i < inputTy.getRank(); i++) {
-            auto index =
+            Value index =
                 rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
             if (i == axis) {
               auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 931a85b636f9d..d8070b34a761d 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1033,7 +1033,7 @@ struct UnrollTransferWriteConversion
     auto vec = getDataVector(xferOp);
     auto xferVecType = xferOp.getVectorType();
     int64_t dimSize = xferVecType.getShape()[0];
-    auto source = xferOp.getSource(); // memref or tensor to be written to.
+    Value source = xferOp.getSource(); // memref or tensor to be written to.
     auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
 
     // Generate fully unrolled loop of transfer ops.

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 1e2a7b017b7ff..9cd2331c24d19 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1056,7 +1056,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
     //   %t = sparse_tensor.ConvertOp %tmp
     RankedTensorType cooTp =
         getUnorderedCOOFromTypeWithOrdering(dstTp, encDst.getDimOrdering());
-    auto cooBuffer =
+    Value cooBuffer =
         rewriter.create<AllocTensorOp>(loc, cooTp, dynSizesArray).getResult();
 
     Value c0 = constantIndex(rewriter, loc, 0);

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index d5554ea8fbd26..98c75510b46e7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -173,7 +173,8 @@ class StorageLayout {
 
 class SparseTensorSpecifier {
 public:
-  explicit SparseTensorSpecifier(Value specifier) : specifier(specifier) {}
+  explicit SparseTensorSpecifier(Value specifier)
+      : specifier(cast<TypedValue<StorageSpecifierType>>(specifier)) {}
 
   // Undef value for dimension sizes, all zero value for memory sizes.
   static Value getInitValue(OpBuilder &builder, Location loc,

diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td
index 7a76f98ead581..68a9def83c2e0 100644
--- a/mlir/test/mlir-tblgen/op-operand.td
+++ b/mlir/test/mlir-tblgen/op-operand.td
@@ -43,7 +43,7 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
 // CHECK-NEXT: return getODSOperands(0);
 
 // CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpD::getInput2
-// CHECK-NEXT: return *getODSOperands(1).begin();
+// CHECK-NEXT: return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSOperands(1).begin());
 
 // CHECK-LABEL: OpD::build
 // CHECK-NEXT: odsState.addOperands(input1);

diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index 06178f8489c00..d49bffa9cb441 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -100,7 +100,7 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
 // CHECK-NEXT:    return getODSResults(0);
 
 // CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpI::getOutput2
-// CHECK-NEXT:    return *getODSResults(1).begin();
+// CHECK-NEXT:    return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSResults(1).begin());
 
 // CHECK-LABEL: OpI::build
 // CHECK-NEXT:    odsState.addTypes(output1);

diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 7ed29f91d3d64..bc3e2599e72d7 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1337,10 +1337,12 @@ generateNamedOperandGetters(const Operator &op, Class &opClass,
                                 : generateTypeForGetter(operand),
                             name);
       ERROR_IF_PRUNED(m, name, op);
-      m->body().indent() << formatv(
-          "auto operands = getODSOperands({0});\n"
-          "return operands.empty() ? {1}{{} : *operands.begin();",
-          i, rangeElementType);
+      m->body().indent() << formatv("auto operands = getODSOperands({0});\n"
+                                    "return operands.empty() ? {1}{{} : ",
+                                    i, m->getReturnType());
+      if (!isGenericAdaptorBase)
+        m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType());
+      m->body() << "(*operands.begin());";
     } else if (operand.isVariadicOfVariadic()) {
       std::string segmentAttr = op.getGetterName(
           operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
@@ -1366,7 +1368,10 @@ generateNamedOperandGetters(const Operator &op, Class &opClass,
                                 : generateTypeForGetter(operand),
                             name);
       ERROR_IF_PRUNED(m, name, op);
-      m->body() << "  return *getODSOperands(" << i << ").begin();";
+      m->body().indent() << "return ";
+      if (!isGenericAdaptorBase)
+        m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType());
+      m->body() << llvm::formatv("(*getODSOperands({0}).begin());", i);
     }
   }
 }
@@ -1489,9 +1494,11 @@ void OpEmitter::genNamedResultGetters() {
     if (result.isOptional()) {
       m = opClass.addMethod(generateTypeForGetter(result), name);
       ERROR_IF_PRUNED(m, name, op);
-      m->body()
-          << "  auto results = getODSResults(" << i << ");\n"
-          << "  return results.empty() ? ::mlir::Value() : *results.begin();";
+      m->body() << "  auto results = getODSResults(" << i << ");\n"
+                << llvm::formatv("  return results.empty()"
+                                 " ? {0}()"
+                                 " : ::llvm::cast<{0}>(*results.begin());",
+                                 m->getReturnType());
     } else if (result.isVariadic()) {
       m = opClass.addMethod("::mlir::Operation::result_range", name);
       ERROR_IF_PRUNED(m, name, op);
@@ -1499,7 +1506,9 @@ void OpEmitter::genNamedResultGetters() {
     } else {
       m = opClass.addMethod(generateTypeForGetter(result), name);
       ERROR_IF_PRUNED(m, name, op);
-      m->body() << "  return *getODSResults(" << i << ").begin();";
+      m->body() << llvm::formatv(
+          "  return ::llvm::cast<{0}>(*getODSResults({1}).begin());",
+          m->getReturnType(), i);
     }
   }
 }

diff --git a/mlir/unittests/IR/IRMapping.cpp b/mlir/unittests/IR/IRMapping.cpp
index bf00e5eea2260..83627975006ee 100644
--- a/mlir/unittests/IR/IRMapping.cpp
+++ b/mlir/unittests/IR/IRMapping.cpp
@@ -32,7 +32,7 @@ TEST(IRMapping, TypedValue) {
 
   IRMapping mapping;
   mapping.map(i64Val, f64Val);
-  TypedValue<IntegerType> typedI64Val = i64Val;
+  auto typedI64Val = cast<TypedValue<IntegerType>>(i64Val);
   EXPECT_EQ(mapping.lookup(typedI64Val), f64Val);
 }
 



