[Mlir-commits] [mlir] 1b8465a - [mlir] Add CastInfo for mlir classes subclassing from PointerUnion

Tres Popp llvmlistbot at llvm.org
Thu May 25 22:47:52 PDT 2023


Author: Tres Popp
Date: 2023-05-26T07:47:03+02:00
New Revision: 1b8465aac4368c64d3e78ebd94fb8ca048b9e801

URL: https://github.com/llvm/llvm-project/commit/1b8465aac4368c64d3e78ebd94fb8ca048b9e801
DIFF: https://github.com/llvm/llvm-project/commit/1b8465aac4368c64d3e78ebd94fb8ca048b9e801.diff

LOG: [mlir] Add CastInfo for mlir classes subclassing from PointerUnion

This is required to use the free-function variants of cast/isa/dyn_cast/etc.
on them.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/DataFlowFramework.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/Unit.h
    mlir/include/mlir/Interfaces/CallInterfaces.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 68f3db7d507d4..9649f918faa2f 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -470,6 +470,16 @@ namespace llvm {
 template <>
 struct DenseMapInfo<mlir::ProgramPoint>
     : public DenseMapInfo<mlir::ProgramPoint::ParentTy> {};
+
+// Allow llvm::cast style functions.
+template <typename To>
+struct CastInfo<To, mlir::ProgramPoint>
+    : public CastInfo<To, mlir::ProgramPoint::PointerUnion> {};
+
+template <typename To>
+struct CastInfo<To, const mlir::ProgramPoint>
+    : public CastInfo<To, const mlir::ProgramPoint::PointerUnion> {};
+
 } // end namespace llvm
 
 #endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 3725cdd838b9e..5f88e10cd5ae9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -236,4 +236,17 @@ SmallVector<IntT> convertArrayToIndices(ArrayAttr attrs) {
 } // namespace LLVM
 } // namespace mlir
 
+namespace llvm {
+
+// Allow llvm::cast style functions.
+template <typename To>
+struct CastInfo<To, mlir::LLVM::GEPArg>
+    : public CastInfo<To, mlir::LLVM::GEPArg::PointerUnion> {};
+
+template <typename To>
+struct CastInfo<To, const mlir::LLVM::GEPArg>
+    : public CastInfo<To, const mlir::LLVM::GEPArg::PointerUnion> {};
+
+} // namespace llvm
+
 #endif // MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 71864def4543e..f3734dc648275 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -269,15 +269,35 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
   void dump() const { llvm::errs() << *this << "\n"; }
 };
 
+// Temporarily exit the MLIR namespace to add casting support, as later code in
+// this file uses it. The CastInfo must come after the OpFoldResult definition
+// and before any cast function calls depending on CastInfo.
+
+} // namespace mlir
+
+namespace llvm {
+
+// Allow llvm::cast style functions.
+template <typename To>
+struct CastInfo<To, mlir::OpFoldResult>
+    : public CastInfo<To, mlir::OpFoldResult::PointerUnion> {};
+
+template <typename To>
+struct CastInfo<To, const mlir::OpFoldResult>
+    : public CastInfo<To, const mlir::OpFoldResult::PointerUnion> {};
+
+} // namespace llvm
+
+namespace mlir {
+
 /// Allow printing to a stream.
 inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
-  if (Value value = ofr.dyn_cast<Value>())
+  if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
     value.print(os);
   else
-    ofr.dyn_cast<Attribute>().print(os);
+    llvm::dyn_cast_if_present<Attribute>(ofr).print(os);
   return os;
 }
-
 /// Allow printing to a stream.
 inline raw_ostream &operator<<(raw_ostream &os, OpState op) {
   op.print(os, OpPrintingFlags().useLocalScope());
@@ -1554,7 +1574,7 @@ foldTrait(Operation *op, ArrayRef<Attribute> operands,
     return failure();
 
   if (OpFoldResult result = Trait::foldTrait(op, operands)) {
-    if (result.template dyn_cast<Value>() != op->getResult(0))
+    if (llvm::dyn_cast_if_present<Value>(result) != op->getResult(0))
       results.push_back(result);
     return success();
   }
@@ -1903,7 +1923,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
 
     // If the fold failed or was in-place, try to fold the traits of the
     // operation.
-    if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
+    if (!result ||
+        llvm::dyn_cast_if_present<Value>(result) == op->getResult(0)) {
       if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
               op, operands, results)))
         return success();
@@ -2119,7 +2140,6 @@ struct DenseMapInfo<T,
   }
   static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
 };
-
 } // namespace llvm
 
 #endif

diff  --git a/mlir/include/mlir/IR/Unit.h b/mlir/include/mlir/IR/Unit.h
index 033dab5974516..63117a7664a7d 100644
--- a/mlir/include/mlir/IR/Unit.h
+++ b/mlir/include/mlir/IR/Unit.h
@@ -39,4 +39,17 @@ raw_ostream &operator<<(raw_ostream &os, const IRUnit &unit);
 
 } // end namespace mlir
 
+namespace llvm {
+
+// Allow llvm::cast style functions.
+template <typename To>
+struct CastInfo<To, mlir::IRUnit>
+    : public CastInfo<To, mlir::IRUnit::PointerUnion> {};
+
+template <typename To>
+struct CastInfo<To, const mlir::IRUnit>
+    : public CastInfo<To, const mlir::IRUnit::PointerUnion> {};
+
+} // namespace llvm
+
 #endif // MLIR_IR_UNIT_H

diff  --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index 26a245eba3e58..7dbcddb01b241 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -30,10 +30,15 @@ struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
 
 namespace llvm {
 
+// Allow llvm::cast style functions.
 template <typename To>
 struct CastInfo<To, mlir::CallInterfaceCallable>
     : public CastInfo<To, mlir::CallInterfaceCallable::PointerUnion> {};
 
+template <typename To>
+struct CastInfo<To, const mlir::CallInterfaceCallable>
+    : public CastInfo<To, const mlir::CallInterfaceCallable::PointerUnion> {};
+
 } // namespace llvm
 
 #endif // MLIR_INTERFACES_CALLINTERFACES_H


        


More information about the Mlir-commits mailing list