[Mlir-commits] [mlir] [mlir] fix undefined behavior in CastInfo::castFailed with From=<MLIR interface> (PR #87145)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 11 23:31:26 PDT 2024


https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/87145

>From a58e3d2aa05038241c62f36debdc44450b24b961 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Fri, 29 Mar 2024 23:25:07 +0800
Subject: [PATCH 1/2] [mlir] fix undefined behavior in CastInfo::castFailed
 with From=<MLIR interface>

Fixes https://github.com/llvm/llvm-project/issues/86647
---
 mlir/include/mlir/IR/OpDefinition.h        | 29 ++++++++++++++++
 mlir/tools/mlir-tblgen/OpInterfacesGen.cpp |  3 +-
 mlir/unittests/IR/InterfaceTest.cpp        | 40 ++++++++++++++++++++++
 3 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 59f094d669099..5610daadfbecb 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -22,6 +22,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/ODSSupport.h"
 #include "mlir/IR/Operation.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 #include <optional>
@@ -2142,6 +2143,34 @@ struct DenseMapInfo<T,
   }
   static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
 };
+
+template <typename To, typename From> // Interface-to-interface casts.
+struct CastInfo<
+    To, From,
+    std::enable_if_t<
+        std::is_base_of_v<mlir::OpInterface<To, typename To::InterfaceTraits>,
+                          To> && // `To` is an op interface.
+            std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
+                                                typename std::remove_const_t<
+                                                    From>::InterfaceTraits>,
+                              std::remove_const_t<From>>, // So is `From`.
+        void>> : NullableValueCastFailed<To>,
+                 DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+  static bool isPossible(From &val) {
+    if constexpr (std::is_same_v<To, From>)
+      return true;
+    else
+      return mlir::OpInterface<To, typename To::InterfaceTraits>::
+          InterfaceBase::classof(
+              const_cast<std::remove_const_t<From> &>(val).getOperation());
+  }
+
+  static To doCast(From &val) {
+    return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
+  }
+};
+
 } // namespace llvm
 
 #endif
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 4b06b92fbc8a8..a1cae23c1df90 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -544,7 +544,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
   // Emit the main interface class declaration.
   os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
                       "public:\n"
-                      "  using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
+                      "  using ::mlir::{3}<{1}, detail::{2}>::{3};\n"
+                      "  using InterfaceTraits = detail::{2};\n", // Used by CastInfo.
                       interfaceName, interfaceName, interfaceTraitsName,
                       interfaceBaseType);
 
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 42196b003e7da..741365b3efb5f 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -17,6 +17,9 @@
 #include "../../test/lib/Dialect/Test/TestDialect.h"
 #include "../../test/lib/Dialect/Test/TestOps.h"
 #include "../../test/lib/Dialect/Test/TestTypes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Parser/Parser.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace test;
@@ -84,3 +87,40 @@ TEST(InterfaceTest, TestImplicitConversion) {
   typeA = typeB;
   EXPECT_EQ(typeA, typeB);
 }
+
+TEST(OperationInterfaceTest, CastOpToInterface) {
+  DialectRegistry registry;
+  MLIRContext ctx;
+
+  const char *ir = R"MLIR(
+      func.func @map(%arg : tensor<1xi64>) {
+        %0 = arith.constant dense<[10]> : tensor<1xi64>
+        %1 = arith.addi %arg, %0 : tensor<1xi64>
+        return
+      }
+    )MLIR"; // @map's first op (arith.constant) is the one cast below.
+
+  registry.insert<func::FuncDialect, arith::ArithDialect>();
+  ctx.appendDialectRegistry(registry);
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
+
+  OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op);
+
+  bool constantOp =
+      llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
+          .Case<VectorUnrollOpInterface, arith::ConstantOp>([&](auto op) {
+            return std::is_same_v<decltype(op), arith::ConstantOp>;
+          });
+
+  EXPECT_TRUE(constantOp);
+
+  EXPECT_FALSE(llvm::isa<VectorUnrollOpInterface>(interface));
+  EXPECT_FALSE(llvm::dyn_cast<VectorUnrollOpInterface>(interface));
+
+  EXPECT_TRUE(llvm::isa<InferTypeOpInterface>(interface));
+  EXPECT_TRUE(llvm::dyn_cast<InferTypeOpInterface>(interface));
+
+  EXPECT_TRUE(llvm::isa<OpAsmOpInterface>(interface)); // Identity cast.
+  EXPECT_TRUE(llvm::dyn_cast<OpAsmOpInterface>(interface));
+}

>From e00ec5795b9eb9edd9803dbf666a40bed50cd540 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Thu, 11 Jul 2024 11:06:45 -0400
Subject: [PATCH 2/2] add CastInfo to support casting an op interface to an op

---
 mlir/include/mlir/IR/OpDefinition.h | 46 +++++++++++++++++++++++++++--
 mlir/include/mlir/TableGen/Class.h  |  2 ++
 mlir/tools/mlir-tblgen/OpClass.cpp  |  9 ++++++
 mlir/unittests/IR/InterfaceTest.cpp | 20 +++++++++----
 4 files changed, 69 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 5610daadfbecb..52aac19289cf4 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -2157,7 +2157,7 @@ struct CastInfo<
         void>> : NullableValueCastFailed<To>,
                  DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
 
-  static bool isPossible(From &val) {
+  static inline bool isPossible(From &val) {
     if constexpr (std::is_same_v<To, From>)
       return true;
     else
@@ -2166,7 +2166,55 @@ struct CastInfo<
               const_cast<std::remove_const_t<From> &>(val).getOperation());
   }
 
-  static To doCast(From &val) {
+  static inline To doCast(From &val) {
+    return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
+  }
+};
+
+/// Trait that is true when `OpT` derives from `mlir::Op<OpT, Traits...>`,
+/// where `Traits...` are recovered from the tablegen-emitted `OpT::traits`
+/// tuple. Types without a `traits` member (e.g. op interfaces, or ops not
+/// emitted by mlir-tblgen) evaluate to false.
+template <typename OpT, typename = void>
+struct is_concrete_op_type : public std::false_type {};
+
+/// Maps `std::tuple<Traits<OpT>...>` back to `mlir::Op<OpT, Traits...>`. It is
+/// only named inside `decltype`, so it is declared but never defined.
+template <typename OpT, template <typename T> typename... Traits>
+mlir::Op<OpT, Traits...> concrete_op_base_type_impl(std::tuple<Traits<OpT>...>);
+
+template <typename OpT>
+using concrete_op_base_type =
+    decltype(concrete_op_base_type_impl<OpT>(typename OpT::traits()));
+
+template <typename OpT>
+struct is_concrete_op_type<
+    OpT, std::enable_if_t<std::is_base_of_v<concrete_op_base_type<OpT>, OpT>>>
+    : public std::true_type {};
+
+/// Allows casting an op interface (e.g. `OpAsmOpInterface`) directly to a
+/// concrete op class (e.g. `arith::ConstantOp`). The cast succeeds iff the
+/// operation wrapped by the interface is a `To`.
+template <typename To, typename From>
+struct CastInfo<
+    To, From,
+    std::enable_if_t<
+        is_concrete_op_type<To>::value &&
+        std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
+                                            typename std::remove_const_t<
+                                                From>::InterfaceTraits>,
+                          std::remove_const_t<From>>>>
+    : NullableValueCastFailed<To>,
+      DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+  static inline bool isPossible(From &val) {
+    // `To` is a concrete op and `From` an op interface, so they are never the
+    // same type: the answer always depends on the wrapped operation.
+    return isa<To>(
+        const_cast<std::remove_const_t<From> &>(val).getOperation());
+  }
+
+  static inline To doCast(From &val) {
     return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
   }
 };
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 92fec6a3b11d9..7616f56aa2e3d 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -520,6 +520,8 @@ class ParentClass {
   /// Write the parent class declaration.
   void writeTo(raw_indented_ostream &os) const;
 
+  friend class OpClass; // Reads templateParams to emit the `traits` alias.
+
 private:
   /// The fully resolved C++ name of the parent class.
   std::string name;
diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp
index 60fa1833ce625..5426302dfed3e 100644
--- a/mlir/tools/mlir-tblgen/OpClass.cpp
+++ b/mlir/tools/mlir-tblgen/OpClass.cpp
@@ -36,7 +36,18 @@ OpClass::OpClass(StringRef name, std::string extraClassDeclaration,
 }
 
 void OpClass::finalize() {
+  // Emit `using traits = std::tuple<Trait<OpName>...>;` so that C++ code
+  // (llvm::is_concrete_op_type) can recover the op's `::mlir::Op<OpName,
+  // Traits...>` base class. The first template parameter of that base is the
+  // op class itself rather than a trait, so it is skipped.
+  std::string traitList;
+  llvm::raw_string_ostream os(traitList);
+  llvm::interleaveComma(llvm::drop_begin(parent.templateParams), os,
+                        [&](const auto &trait) {
+                          os << trait << "<" << getClassName() << ">";
+                        });
+  declare<UsingDeclaration>("traits", "std::tuple<" + traitList + ">");
   Class::finalize();
   declare<VisibilityDeclaration>(Visibility::Public);
   declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
 }
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 741365b3efb5f..c9ae6938e8b40 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -18,6 +18,7 @@
 #include "../../test/lib/Dialect/Test/TestOps.h"
 #include "../../test/lib/Dialect/Test/TestTypes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Parser/Parser.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -88,7 +89,7 @@ TEST(InterfaceTest, TestImplicitConversion) {
   EXPECT_EQ(typeA, typeB);
 }
 
-TEST(OperationInterfaceTest, CastOpToInterface) {
+TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) {
   DialectRegistry registry;
   MLIRContext ctx;
 
@@ -105,13 +106,25 @@ TEST(OperationInterfaceTest, CastOpToInterface) {
   OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
   Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
 
+  // Compile-time checks of the concrete-op detection used by CastInfo.
+  static_assert(std::is_base_of_v<llvm::concrete_op_base_type<arith::AddIOp>,
+                                  arith::AddIOp>);
+  static_assert(llvm::is_concrete_op_type<arith::AddIOp>::value);
+  static_assert(!llvm::is_concrete_op_type<OpAsmOpInterface>::value);
+
   OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op);
 
-  bool constantOp =
-      llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
-          .Case<VectorUnrollOpInterface, arith::ConstantOp>([&](auto op) {
-            return std::is_same_v<decltype(op), arith::ConstantOp>;
-          });
+  bool constantOp =
+      llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
+          .Case<arith::AddIOp, arith::ConstantOp>([](auto op) {
+            return std::is_same_v<decltype(op), arith::ConstantOp>;
+          });
+
+  // An op interface can be cast directly to a concrete op class.
+  EXPECT_TRUE(llvm::isa<arith::ConstantOp>(interface));
+  EXPECT_FALSE(llvm::isa<arith::AddIOp>(interface));
+  EXPECT_TRUE(llvm::dyn_cast<arith::ConstantOp>(interface));
+  EXPECT_FALSE(llvm::dyn_cast<arith::AddIOp>(interface));
 
   EXPECT_TRUE(constantOp);
 



More information about the Mlir-commits mailing list