[Mlir-commits] [mlir] [mlir] fix undefined behavior in CastInfo::castFailed with From=<MLIR interface> (PR #87145)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 4 04:39:18 PST 2024
https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/87145
>From fd32cf987f381bb68ffde962be1e07ca2c6d5512 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Fri, 29 Mar 2024 23:25:07 +0800
Subject: [PATCH 1/3] [mlir] fix undefined behavior in CastInfo::castFailed
with From=<MLIR interface>
Fixes https://github.com/llvm/llvm-project/issues/86647
---
mlir/include/mlir/IR/OpDefinition.h | 29 ++++++++++++++++
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 3 +-
mlir/unittests/IR/InterfaceTest.cpp | 40 ++++++++++++++++++++++
3 files changed, 71 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 59f094d6690991..5610daadfbecb5 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -22,6 +22,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/ODSSupport.h"
#include "mlir/IR/Operation.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include <optional>
@@ -2142,6 +2143,34 @@ struct DenseMapInfo<T,
}
static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
};
+
+template <typename To, typename From>
+struct CastInfo<
+ To, From,
+ std::enable_if_t<
+ std::is_base_of_v<mlir::OpInterface<To, typename To::InterfaceTraits>,
+ To> &&
+ std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
+ typename std::remove_const_t<
+ From>::InterfaceTraits>,
+ std::remove_const_t<From>>,
+ void>> : NullableValueCastFailed<To>,
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+ static bool isPossible(From &val) {
+ if constexpr (std::is_same_v<To, From>)
+ return true;
+ else
+ return mlir::OpInterface<To, typename To::InterfaceTraits>::
+ InterfaceBase::classof(
+ const_cast<std::remove_const_t<From> &>(val).getOperation());
+ }
+
+ static To doCast(From &val) {
+ return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
+ }
+};
+
} // namespace llvm
#endif
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 1f1b1d9a340391..c8233d19da4b05 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -545,7 +545,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
// Emit the main interface class declaration.
os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
"public:\n"
- " using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
+ " using ::mlir::{3}<{1}, detail::{2}>::{3};\n"
+ " using InterfaceTraits = detail::{2};\n",
interfaceName, interfaceName, interfaceTraitsName,
interfaceBaseType);
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 42196b003e7dad..741365b3efb5fc 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -17,6 +17,9 @@
#include "../../test/lib/Dialect/Test/TestDialect.h"
#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestTypes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Parser/Parser.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace test;
@@ -84,3 +87,40 @@ TEST(InterfaceTest, TestImplicitConversion) {
typeA = typeB;
EXPECT_EQ(typeA, typeB);
}
+
+TEST(OperationInterfaceTest, CastOpToInterface) {
+ DialectRegistry registry;
+ MLIRContext ctx;
+
+ const char *ir = R"MLIR(
+ func.func @map(%arg : tensor<1xi64>) {
+ %0 = arith.constant dense<[10]> : tensor<1xi64>
+ %1 = arith.addi %arg, %0 : tensor<1xi64>
+ return
+ }
+ )MLIR";
+
+ registry.insert<func::FuncDialect, arith::ArithDialect>();
+ ctx.appendDialectRegistry(registry);
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+ Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
+
+ OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op);
+
+ bool constantOp =
+ llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
+ .Case<VectorUnrollOpInterface, arith::ConstantOp>([&](auto op) {
+ return std::is_same_v<decltype(op), arith::ConstantOp>;
+ });
+
+ EXPECT_TRUE(constantOp);
+
+ EXPECT_FALSE(llvm::isa<VectorUnrollOpInterface>(interface));
+ EXPECT_FALSE(llvm::dyn_cast<VectorUnrollOpInterface>(interface));
+
+ EXPECT_TRUE(llvm::isa<InferTypeOpInterface>(interface));
+ EXPECT_TRUE(llvm::dyn_cast<InferTypeOpInterface>(interface));
+
+ EXPECT_TRUE(llvm::isa<OpAsmOpInterface>(interface));
+ EXPECT_TRUE(llvm::dyn_cast<OpAsmOpInterface>(interface));
+}
>From cc59eaca3b16fcccbf6b806037a70a03f02b6767 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Thu, 11 Jul 2024 11:06:45 -0400
Subject: [PATCH 2/3] add CastInfo to support casting an interface to an Op
---
mlir/include/mlir/IR/OpDefinition.h | 53 +++++++++++++++++++++++++++--
mlir/include/mlir/TableGen/Class.h | 2 ++
mlir/tools/mlir-tblgen/OpClass.cpp | 9 +++++
mlir/unittests/IR/InterfaceTest.cpp | 19 +++++++----
4 files changed, 75 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 5610daadfbecb5..763d0f6ff2cb8d 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -2144,6 +2144,9 @@ struct DenseMapInfo<T,
static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
};
+/// Add support for llvm style casts.
+/// We provide a cast between To and From if To and From is mlir::OpInterface or
+/// derives from it.
template <typename To, typename From>
struct CastInfo<
To, From,
@@ -2157,7 +2160,7 @@ struct CastInfo<
void>> : NullableValueCastFailed<To>,
DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
- static bool isPossible(From &val) {
+ static inline bool isPossible(From &val) {
if constexpr (std::is_same_v<To, From>)
return true;
else
@@ -2166,7 +2169,53 @@ struct CastInfo<
const_cast<std::remove_const_t<From> &>(val).getOperation());
}
- static To doCast(From &val) {
+ static inline To doCast(From &val) {
+ return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
+ }
+};
+
+template <typename OpT, typename = void>
+struct is_concrete_op_type : public std::false_type {};
+
+template <typename OpT, template <typename T> typename... Traits>
+constexpr auto concrete_op_base_type_impl(std::tuple<Traits<OpT>...>) {
+ return mlir::Op<OpT, Traits...>(nullptr);
+}
+
+template <typename OpT>
+using concrete_op_base_type =
+ decltype(concrete_op_base_type_impl<OpT>(typename OpT::traits()));
+
+template <typename OpT>
+struct is_concrete_op_type<
+ OpT, std::enable_if_t<std::is_base_of_v<concrete_op_base_type<OpT>, OpT>>>
+ : public std::true_type {};
+
+/// Add support for llvm style casts.
+/// We provide a cast between To and From if To is mlir::Op<ConcreteType,
+/// Trait0, Trait1, ...> or derives from it and From is mlir::OpInterface or
+/// derives from it.
+template <typename To, typename From>
+struct CastInfo<
+ To, From,
+ std::enable_if_t<
+ is_concrete_op_type<To>() &&
+ std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
+ typename std::remove_const_t<
+ From>::InterfaceTraits>,
+ std::remove_const_t<From>>>>
+ : NullableValueCastFailed<To>,
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+ static inline bool isPossible(From &val) {
+ if constexpr (std::is_same_v<To, From>)
+ return true;
+ else
+ return isa<To>(
+ const_cast<std::remove_const_t<From> &>(val).getOperation());
+ }
+
+ static inline To doCast(From &val) {
return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
}
};
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index f750a34a3b2ba4..5cb9aa4e6d21ba 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -521,6 +521,8 @@ class ParentClass {
/// Write the parent class declaration.
void writeTo(raw_indented_ostream &os) const;
+ friend class OpClass;
+
private:
/// The fully resolved C++ name of the parent class.
std::string name;
diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp
index 60fa1833ce625e..5426302dfed3e3 100644
--- a/mlir/tools/mlir-tblgen/OpClass.cpp
+++ b/mlir/tools/mlir-tblgen/OpClass.cpp
@@ -36,7 +36,16 @@ OpClass::OpClass(StringRef name, std::string extraClassDeclaration,
}
void OpClass::finalize() {
+ std::string traitList;
+ llvm::raw_string_ostream os(traitList);
+ iterator_range parentTemplateParams(std::begin(parent.templateParams) + 1,
+ std::end(parent.templateParams));
+ llvm::interleaveComma(parentTemplateParams, os, [&](auto &trait) {
+ os << trait << "<" << getClassName().str() << ">";
+ });
+ declare<UsingDeclaration>("traits", "std::tuple<" + traitList + ">");
Class::finalize();
+
declare<VisibilityDeclaration>(Visibility::Public);
declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
}
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 741365b3efb5fc..6c983385679b18 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -88,7 +88,7 @@ TEST(InterfaceTest, TestImplicitConversion) {
EXPECT_EQ(typeA, typeB);
}
-TEST(OperationInterfaceTest, CastOpToInterface) {
+TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) {
DialectRegistry registry;
MLIRContext ctx;
@@ -105,13 +105,20 @@ TEST(OperationInterfaceTest, CastOpToInterface) {
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
+ static_assert(std::is_base_of_v<llvm::concrete_op_base_type<arith::AddIOp>,
+ arith::AddIOp>,
+ "");
+ static_assert(llvm::is_concrete_op_type<arith::AddIOp>(), "");
+ static_assert(!llvm::is_concrete_op_type<OpAsmOpInterface>(), "");
+
OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op);
- bool constantOp =
- llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
- .Case<VectorUnrollOpInterface, arith::ConstantOp>([&](auto op) {
- return std::is_same_v<decltype(op), arith::ConstantOp>;
- });
+ bool constantOp = llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
+ .Case<arith::AddIOp, arith::ConstantOp>([&](auto op) {
+ bool is_same =
+ std::is_same_v<decltype(op), arith::ConstantOp>;
+ return is_same;
+ });
EXPECT_TRUE(constantOp);
>From 1b21dd261c61b15fb4d12483d50a1a67b1fda659 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Mon, 4 Nov 2024 07:37:40 -0500
Subject: [PATCH 3/3] refine: unify the CastInfo specializations for mlir::OpState-derived types
---
mlir/include/mlir/IR/OpDefinition.h | 57 ++---------------------------
mlir/unittests/IR/InterfaceTest.cpp | 4 +-
2 files changed, 5 insertions(+), 56 deletions(-)
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 763d0f6ff2cb8d..7e067a48629e01 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -2145,68 +2145,17 @@ struct DenseMapInfo<T,
};
/// Add support for llvm style casts.
-/// We provide a cast between To and From if To and From is mlir::OpInterface or
+/// We provide a cast between To and From if To and From is mlir::OpState or
/// derives from it.
template <typename To, typename From>
struct CastInfo<
To, From,
std::enable_if_t<
- std::is_base_of_v<mlir::OpInterface<To, typename To::InterfaceTraits>,
- To> &&
- std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
- typename std::remove_const_t<
- From>::InterfaceTraits>,
- std::remove_const_t<From>>,
+ std::is_base_of_v<mlir::OpState, To> &&
+ std::is_base_of_v<mlir::OpState, std::remove_const_t<From>>,
void>> : NullableValueCastFailed<To>,
DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
- static inline bool isPossible(From &val) {
- if constexpr (std::is_same_v<To, From>)
- return true;
- else
- return mlir::OpInterface<To, typename To::InterfaceTraits>::
- InterfaceBase::classof(
- const_cast<std::remove_const_t<From> &>(val).getOperation());
- }
-
- static inline To doCast(From &val) {
- return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
- }
-};
-
-template <typename OpT, typename = void>
-struct is_concrete_op_type : public std::false_type {};
-
-template <typename OpT, template <typename T> typename... Traits>
-constexpr auto concrete_op_base_type_impl(std::tuple<Traits<OpT>...>) {
- return mlir::Op<OpT, Traits...>(nullptr);
-}
-
-template <typename OpT>
-using concrete_op_base_type =
- decltype(concrete_op_base_type_impl<OpT>(typename OpT::traits()));
-
-template <typename OpT>
-struct is_concrete_op_type<
- OpT, std::enable_if_t<std::is_base_of_v<concrete_op_base_type<OpT>, OpT>>>
- : public std::true_type {};
-
-/// Add support for llvm style casts.
-/// We provide a cast between To and From if To is mlir::Op<ConcreteType,
-/// Trait0, Trait1, ...> or derives from it and From is mlir::OpInterface or
-/// derives from it.
-template <typename To, typename From>
-struct CastInfo<
- To, From,
- std::enable_if_t<
- is_concrete_op_type<To>() &&
- std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
- typename std::remove_const_t<
- From>::InterfaceTraits>,
- std::remove_const_t<From>>>>
- : NullableValueCastFailed<To>,
- DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
-
static inline bool isPossible(From &val) {
if constexpr (std::is_same_v<To, From>)
return true;
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 6c983385679b18..a1e444eeb0e156 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -104,13 +104,13 @@ TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) {
ctx.appendDialectRegistry(registry);
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
-
+ /*
static_assert(std::is_base_of_v<llvm::concrete_op_base_type<arith::AddIOp>,
arith::AddIOp>,
"");
static_assert(llvm::is_concrete_op_type<arith::AddIOp>(), "");
static_assert(!llvm::is_concrete_op_type<OpAsmOpInterface>(), "");
-
+ */
OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op);
bool constantOp = llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
More information about the Mlir-commits
mailing list