[Mlir-commits] [mlir] [mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIR… (PR #87145)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 30 00:19:21 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: long.chen (lipracer)

<details>
<summary>Changes</summary>

Fix undefined behavior in `CastInfo::castFailed` with `From=<MLIR interface>`.

Fixes https://github.com/llvm/llvm-project/issues/86647

---
Full diff: https://github.com/llvm/llvm-project/pull/87145.diff


3 Files Affected:

- (modified) mlir/include/mlir/IR/OpDefinition.h (+28) 
- (modified) mlir/tools/mlir-tblgen/OpInterfacesGen.cpp (+2-1) 
- (modified) mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp (+28) 


``````````diff
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index bd68c27445744e..6da00b4c549a3d 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -22,6 +22,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/ODSSupport.h"
 #include "mlir/IR/Operation.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 #include <optional>
@@ -2110,6 +2111,33 @@ struct DenseMapInfo<T,
   }
   static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
 };
+
+template <typename To, typename From>
+struct CastInfo<
+    To, From,
+    std::enable_if_t<
+        std::is_base_of_v<mlir::OpInterface<To, typename To::InterfaceTraits>,
+                          To> &&
+            std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
+                                                typename std::remove_const_t<
+                                                    From>::InterfaceTraits>,
+                              std::remove_const_t<From>>,
+        void>> : NullableValueCastFailed<To>,
+                 DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+  static bool isPossible(From &val) {
+    if constexpr (std::is_same_v<To, From>)
+      return true;
+    else
+      return isa<To>(
+          const_cast<std::remove_const_t<From> &>(val).getOperation());
+  }
+
+  static To doCast(From &val) {
+    return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
+  }
+};
+
 } // namespace llvm
 
 #endif
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 2a7406f42f34b5..c6409e9ec30ec9 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -544,7 +544,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
   // Emit the main interface class declaration.
   os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
                       "public:\n"
-                      "  using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
+                      "  using ::mlir::{3}<{1}, detail::{2}>::{3};\n"
+                      "  using InterfaceTraits = detail::{2};\n",
                       interfaceName, interfaceName, interfaceTraitsName,
                       interfaceBaseType);
 
diff --git a/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
index 2fc8a43f7c04c6..686bac8c8aaac0 100644
--- a/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
+++ b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Parser/Parser.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 #include <gtest/gtest.h>
 
@@ -103,3 +104,30 @@ TEST_F(ValueShapeRangeTest, SettingShapes) {
   EXPECT_EQ(range.getShape(1).getDimSize(0), 1);
   EXPECT_FALSE(range.getShape(2));
 }
+
+TEST(OperationInterfaceTest, CastOpToInterface) {
+  DialectRegistry registry;
+  MLIRContext ctx;
+
+  const char *ir = R"MLIR(
+      func.func @map(%arg : tensor<1xi64>) {
+        %0 = arith.constant dense<[10]> : tensor<1xi64>
+        %1 = arith.addi %arg, %0 : tensor<1xi64>
+        return
+      }
+    )MLIR";
+
+  registry.insert<func::FuncDialect, arith::ArithDialect>();
+  ctx.appendDialectRegistry(registry);
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
+
+  ::mlir::OpAsmOpInterface interface = llvm::cast<::mlir::OpAsmOpInterface>(op);
+
+  std::string funcName;
+  llvm::TypeSwitch<::mlir::OpAsmOpInterface, void>(interface)
+      .Case<::mlir::VectorUnrollOpInterface, arith::ConstantOp>(
+          [&](auto op) { funcName = __PRETTY_FUNCTION__; });
+
+  EXPECT_TRUE(funcName.find("ConstantOp") != std::string::npos);
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/87145


More information about the Mlir-commits mailing list