[Mlir-commits] [mlir] 688c51a - [MLIR] Add mlir::TypedValue

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Aug 28 20:26:50 PDT 2022


Author: Tyker
Date: 2022-08-28T20:26:37-07:00
New Revision: 688c51a5acc53b456014e53663051476d825e896

URL: https://github.com/llvm/llvm-project/commit/688c51a5acc53b456014e53663051476d825e896
DIFF: https://github.com/llvm/llvm-project/commit/688c51a5acc53b456014e53663051476d825e896.diff

LOG: [MLIR] Add mlir::TypedValue

mlir::TypedValue is a wrapper class for mlir::Values with a statically known type.
getType will return the known type, and all assignments will be checked.

The TableGen operation generator was also adapted to use mlir::TypedValue
where appropriate.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
    mlir/include/mlir/IR/Value.h
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Dialect/SCF/Utils/Utils.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/lib/Dialect/Test/TestDialect.h
    mlir/test/mlir-tblgen/op-attribute.td
    mlir/test/mlir-tblgen/op-decl-and-defs.td
    mlir/test/mlir-tblgen/op-operand.td
    mlir/test/mlir-tblgen/op-result.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
index 664797bb8b84e..bc51845b96064 100644
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
 
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
 

diff  --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 2f0524c4da7df..c9ff4f0145938 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -419,6 +419,27 @@ inline unsigned OpResultImpl::getResultNumber() const {
   return cast<InlineOpResult>(this)->getResultNumber();
 }
 
+/// TypedValue is a Value with a statically known type.
+/// A TypedValue can be null/empty.
+template <typename Ty>
+struct TypedValue : Value {
+  /// Return the known Type
+  Ty getType() { return Value::getType().template cast<Ty>(); }
+  void setType(mlir::Type ty) {
+    assert(ty.template isa<Ty>());
+    Value::setType(ty);
+  }
+
+  TypedValue(Value val) : Value(val) {
+    assert(!val || val.getType().template isa<Ty>());
+  }
+  TypedValue &operator=(const Value &other) {
+    assert(!other || other.getType().template isa<Ty>());
+    Value::operator=(other);
+    return *this;
+  }
+};
+
 } // namespace detail
 
 /// This is a value defined by a result of an operation.
@@ -459,6 +480,12 @@ inline ::llvm::hash_code hash_value(Value arg) {
   return ::llvm::hash_value(arg.getImpl());
 }
 
+template <typename Ty, typename Value = mlir::Value>
+/// If Ty is mlir::Type, this selects plain `mlir::Value` instead of a wrapper
+/// around it. This helps resolve ambiguous conversion issues.
+using TypedValue = std::conditional_t<std::is_same_v<Ty, mlir::Type>,
+                                      mlir::Value, detail::TypedValue<Ty>>;
+
 } // namespace mlir
 
 namespace llvm {

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 8756603acf3e6..e4448ff1f3100 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -784,7 +784,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
       launchOp.getKernelName().getValue(), loc, rewriter);
   auto function = moduleGetFunctionCallBuilder.create(
       loc, rewriter, {module.getResult(), kernelName});
-  auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
+  Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
   Value stream =
       adaptor.asyncDependencies().empty()
           ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()

diff  --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 7569476cb7cd1..501d95b3c791d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -428,7 +428,7 @@ LogicalResult mlir::loopUnrollByFactor(
   // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
   OpBuilder boundsBuilder(forOp);
   auto loc = forOp.getLoc();
-  auto step = forOp.getStep();
+  Value step = forOp.getStep();
   Value upperBoundUnrolled;
   Value stepUnrolled;
   bool generateEpilogueLoop = true;

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2b8cae83fd414..2227064f6791c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -778,7 +778,7 @@ struct CanonicalizeCastExtentTensorOperandsPattern
                                 PatternRewriter &rewriter) const override {
     // Canonicalize operands.
     bool anyChange = false;
-    auto canonicalizeOperand = [&](Value operand) {
+    auto canonicalizeOperand = [&](Value operand) -> Value {
       if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
         // Only eliminate the cast if it holds no shape information.
         bool isInformationLoosingCast =

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 0e583d38aa1a5..ceb9dc6f4c933 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_TESTDIALECT_H
 #define MLIR_TESTDIALECT_H
 
+#include "TestTypes.h"
 #include "TestAttributes.h"
 #include "TestInterfaces.h"
 #include "mlir/Dialect/DLTI/DLTI.h"

diff  --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 3c85d71460dc5..3e632d7502179 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -471,8 +471,8 @@ def OpWithDefaultAndSuccessor : NS_Op<"default_with_succ", []> {
 // DECL: static void build({{.*}}, bool dv_bool_attr, ::mlir::BlockRange succ)
 
 // DEF-LABEL: MixOperandsAndAttrs definitions
-// DEF-DAG: ::mlir::Value MixOperandsAndAttrs::operand()
-// DEF-DAG: ::mlir::Value MixOperandsAndAttrs::otherArg()
+// DEF-DAG: ::mlir::TypedValue<::mlir::FloatType> MixOperandsAndAttrs::operand()
+// DEF-DAG: ::mlir::TypedValue<::mlir::FloatType> MixOperandsAndAttrs::otherArg()
 // DEF-DAG: void MixOperandsAndAttrs::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::FloatAttr attr, ::mlir::Value operand, ::mlir::FloatAttr otherAttr, ::mlir::Value otherArg)
 // DEF-DAG: ::llvm::APFloat MixOperandsAndAttrs::attr()
 // DEF-DAG: ::llvm::APFloat MixOperandsAndAttrs::otherAttr()

diff  --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index cab95c3294673..76d76c974a733 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -78,12 +78,12 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:     return ::llvm::StringLiteral("test.a_op");
 // CHECK:   }
 // CHECK:   ::mlir::Operation::operand_range getODSOperands(unsigned index);
-// CHECK:   ::mlir::Value getA();
+// CHECK:   ::mlir::TypedValue<::mlir::IntegerType> getA();
 // CHECK:   ::mlir::Operation::operand_range getB();
 // CHECK:   ::mlir::MutableOperandRange getAMutable();
 // CHECK:   ::mlir::MutableOperandRange getBMutable();
 // CHECK:   ::mlir::Operation::result_range getODSResults(unsigned index);
-// CHECK:   ::mlir::Value getR();
+// CHECK:   ::mlir::TypedValue<::mlir::IntegerType> getR();
 // CHECK:   ::mlir::Region &getSomeRegion();
 // CHECK:   ::mlir::MutableArrayRef<::mlir::Region> getSomeRegions();
 // CHECK:   ::mlir::IntegerAttr getAttr1Attr()
@@ -169,7 +169,7 @@ def NS_EOp : NS_Op<"op_with_optionals", []> {
 // CHECK-LABEL: NS::EOp declarations
 // CHECK:   ::mlir::Value getA();
 // CHECK:   ::mlir::MutableOperandRange getAMutable();
-// CHECK:   ::mlir::Value getB();
+// CHECK:   ::mlir::TypedValue<::mlir::FloatType> getB();
 // CHECK:   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::Type b, /*optional*/::mlir::Value a)
 
 // Check that all types match constraint results in generating builder.

diff  --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td
index 7e2c2c55e0a4b..c4f904dc0a14a 100644
--- a/mlir/test/mlir-tblgen/op-operand.td
+++ b/mlir/test/mlir-tblgen/op-operand.td
@@ -54,7 +54,7 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
 // CHECK-LABEL: ::mlir::Operation::operand_range OpD::input1
 // CHECK-NEXT: return getODSOperands(0);
 
-// CHECK-LABEL: ::mlir::Value OpD::input2
+// CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpD::input2
 // CHECK-NEXT: return *getODSOperands(1).begin();
 
 // CHECK-LABEL: OpD::build

diff  --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index cf7fb6c30870d..5c356e0137fd7 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -102,7 +102,7 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
 // CHECK-LABEL: ::mlir::Operation::result_range OpI::output1
 // CHECK-NEXT:    return getODSResults(0);
 
-// CHECK-LABEL: ::mlir::Value OpI::output2
+// CHECK-LABEL: ::mlir::TypedValue<::mlir::TensorType> OpI::output2
 // CHECK-NEXT:    return *getODSResults(1).begin();
 
 // CHECK-LABEL: OpI::build

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 9d5c391224c0d..70c22e17f27ee 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1152,6 +1152,25 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
   }
 }
 
+static std::string generateTypeForGetter(bool isAdaptor,
+                                         const NamedTypeConstraint &value) {
+  std::string str = "::mlir::Value";
+  /// If the CPPClassName is not a fully qualified type, uses of the type
+  /// across dialects fail because it is not in the correct namespace. So we
+  /// don't generate TypedValue unless the type is fully qualified.
+  /// Note: getCPPClassName doesn't return the fully qualified path for
+  /// `mlir::pdl::OperationType`; see
+  /// https://github.com/llvm/llvm-project/issues/57279.
+  /// Adaptors may hold values whose types differ from their operation's, and
+  /// this is expected, so we don't generate TypedValue for adaptors.
+  if (!isAdaptor && value.constraint.getCPPClassName() != "::mlir::Type" &&
+      StringRef(value.constraint.getCPPClassName()).startswith("::"))
+    str = llvm::formatv("::mlir::TypedValue<{0}>",
+                        value.constraint.getCPPClassName())
+              .str();
+  return str;
+}
+
 // Generates the named operand getter methods for the given Operator `op` and
 // puts them in `opClass`.  Uses `rangeType` as the return type of getters that
 // return a range of operands (individual operands are `Value ` and each
@@ -1216,7 +1235,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
       continue;
     for (StringRef name : op.getGetterNames(operand.name)) {
       if (operand.isOptional()) {
-        m = opClass.addMethod("::mlir::Value", name);
+        m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
         ERROR_IF_PRUNED(m, name, op);
         m->body() << "  auto operands = getODSOperands(" << i << ");\n"
                   << "  return operands.empty() ? ::mlir::Value() : "
@@ -1242,7 +1261,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
         ERROR_IF_PRUNED(m, name, op);
         m->body() << "  return getODSOperands(" << i << ");";
       } else {
-        m = opClass.addMethod("::mlir::Value", name);
+        m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
         ERROR_IF_PRUNED(m, name, op);
         m->body() << "  return *getODSOperands(" << i << ").begin();";
       }
@@ -1365,7 +1384,8 @@ void OpEmitter::genNamedResultGetters() {
       continue;
     for (StringRef name : op.getGetterNames(result.name)) {
       if (result.isOptional()) {
-        m = opClass.addMethod("::mlir::Value", name);
+        m = opClass.addMethod(
+            generateTypeForGetter(/*isAdaptor=*/false, result), name);
         ERROR_IF_PRUNED(m, name, op);
         m->body()
             << "  auto results = getODSResults(" << i << ");\n"
@@ -1375,7 +1395,8 @@ void OpEmitter::genNamedResultGetters() {
         ERROR_IF_PRUNED(m, name, op);
         m->body() << "  return getODSResults(" << i << ");";
       } else {
-        m = opClass.addMethod("::mlir::Value", name);
+        m = opClass.addMethod(
+            generateTypeForGetter(/*isAdaptor=*/false, result), name);
         ERROR_IF_PRUNED(m, name, op);
         m->body() << "  return *getODSResults(" << i << ").begin();";
       }


        


More information about the Mlir-commits mailing list