[Mlir-commits] [mlir] [mlir][arith] Check for valid IR in BitcastOp::fold. (PR #100743)

Ingo Müller llvmlistbot at llvm.org
Sun Jul 28 23:37:13 PDT 2024


https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/100743

>From a8b01cd5fa21fb85db91d02a57a9191a88a18823 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 26 Jul 2024 13:30:21 +0000
Subject: [PATCH 1/2] [mlir][arith] Check for valid IR in BitcastOp::fold.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR prevents `BitcastOp::fold` from creating invalid `IntegerAttr` and
`FloatAttr` values, which would otherwise result in failed assertions. This
can happen if the input IR is invalid. The PR adds checks for whether the
to-be-created attribute verifies and returns early from `fold` if it
doesn't.

Signed-off-by: Ingo Müller <ingomueller at google.com>
---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 21 +++++++--
 .../Dialect/Arith/ArithOpsFoldersTest.cpp     | 46 +++++++++++++++++++
 mlir/unittests/Dialect/Arith/CMakeLists.txt   |  7 +++
 mlir/unittests/Dialect/CMakeLists.txt         |  1 +
 4 files changed, 72 insertions(+), 3 deletions(-)
 create mode 100644 mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
 create mode 100644 mlir/unittests/Dialect/Arith/CMakeLists.txt

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index aa5eb95a3d22e..55e8fcde0c498 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1740,9 +1740,24 @@ OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
                    ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
                    : llvm::cast<IntegerAttr>(operand).getValue();
 
-  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
-    return FloatAttr::get(resType,
-                          APFloat(resFloatType.getFloatSemantics(), bits));
+  /// If bitwidth aren't the same, don't fold.
+  if (resType.getIntOrFloatBitWidth() != bits.getBitWidth())
+    return {};
+
+  MLIRContext *ctx = getContext();
+  auto emitErrorFn = [=] { return ::emitError(UnknownLoc::get(ctx)); };
+
+  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType)) {
+    /// If bits don't represent a valid float, don't fold.
+    APFloat floatBits(resFloatType.getFloatSemantics(), bits);
+    if (failed(FloatAttr::verify(emitErrorFn, resType, floatBits)))
+      return {};
+    return FloatAttr::get(resType, floatBits);
+  }
+
+  /// If bits don't represent a valid integer, don't fold.
+  if (failed(IntegerAttr::verify(emitErrorFn, resType, bits)))
+    return {};
   return IntegerAttr::get(resType, bits);
 }
 
diff --git a/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp b/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
new file mode 100644
index 0000000000000..06c255f01ad69
--- /dev/null
+++ b/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
@@ -0,0 +1,46 @@
+//===- ArithOpsFoldersTest.cpp - unit tests for arith op folders ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Verifier.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+// Tests a regression that made `BitcastOp::fold` crash on invalid input IR, see
+// #100743,
+TEST(BitcastOpTest, FoldInteger) {
+  MLIRContext context;
+  context.loadDialect<arith::ArithDialect>();
+  auto loc = UnknownLoc::get(&context);
+  auto module = ModuleOp::create(loc);
+  OpBuilder builder(module.getBodyRegion());
+  Value i32Val = builder.create<arith::ConstantOp>(
+      loc, builder.getI32Type(), builder.getI32IntegerAttr(0));
+  // This would create an invalid op: `bitcast` can't cast different bitwidths.
+  builder.createOrFold<arith::BitcastOp>(loc, builder.getI64Type(), i32Val);
+  ASSERT_TRUE(failed(verify(module)));
+}
+
+// Tests a regression that made `BitcastOp::fold` crash on invalid input IR, see
+// #100743,
+TEST(BitcastOpTest, FoldFloat) {
+  MLIRContext context;
+  context.loadDialect<arith::ArithDialect>();
+  auto loc = UnknownLoc::get(&context);
+  auto module = ModuleOp::create(loc);
+  OpBuilder builder(module.getBodyRegion());
+  Value f32Val = builder.create<arith::ConstantOp>(loc, builder.getF32Type(),
+                                                   builder.getF32FloatAttr(0));
+  // This would create an invalid op: `bitcast` can't cast different bitwidths.
+  builder.createOrFold<arith::BitcastOp>(loc, builder.getF64Type(), f32Val);
+  ASSERT_TRUE(failed(verify(module)));
+}
+} // namespace
diff --git a/mlir/unittests/Dialect/Arith/CMakeLists.txt b/mlir/unittests/Dialect/Arith/CMakeLists.txt
new file mode 100644
index 0000000000000..ac6b701529d3f
--- /dev/null
+++ b/mlir/unittests/Dialect/Arith/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRArithOpsTests
+  ArithOpsFoldersTest.cpp
+)
+target_link_libraries(MLIRArithOpsTests
+  PRIVATE
+  MLIRArithDialect
+)
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 90a75d5a46ad9..e88e8c61fcb13 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests
   MLIRIR
   MLIRDialect)
 
+add_subdirectory(Arith)
 add_subdirectory(ArmSME)
 add_subdirectory(Index)
 add_subdirectory(LLVMIR)

>From 293c6f02049df809f14b9f51fdd2232f107f91b9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Mon, 29 Jul 2024 06:36:49 +0000
Subject: [PATCH 2/2] Remove unittest. Replace early exits with assert.

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 23 ++--------
 .../Dialect/Arith/ArithOpsFoldersTest.cpp     | 46 -------------------
 mlir/unittests/Dialect/Arith/CMakeLists.txt   |  7 ---
 mlir/unittests/Dialect/CMakeLists.txt         |  1 -
 4 files changed, 5 insertions(+), 72 deletions(-)
 delete mode 100644 mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
 delete mode 100644 mlir/unittests/Dialect/Arith/CMakeLists.txt

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 55e8fcde0c498..641b7d7e2d13b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1739,25 +1739,12 @@ OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
   APInt bits = llvm::isa<FloatAttr>(operand)
                    ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
                    : llvm::cast<IntegerAttr>(operand).getValue();
+  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
+         "trying to fold on broken IR: operands have incompatible types");
 
-  /// If bitwidth aren't the same, don't fold.
-  if (resType.getIntOrFloatBitWidth() != bits.getBitWidth())
-    return {};
-
-  MLIRContext *ctx = getContext();
-  auto emitErrorFn = [=] { return ::emitError(UnknownLoc::get(ctx)); };
-
-  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType)) {
-    /// If bits don't represent a valid float, don't fold.
-    APFloat floatBits(resFloatType.getFloatSemantics(), bits);
-    if (failed(FloatAttr::verify(emitErrorFn, resType, floatBits)))
-      return {};
-    return FloatAttr::get(resType, floatBits);
-  }
-
-  /// If bits don't represent a valid integer, don't fold.
-  if (failed(IntegerAttr::verify(emitErrorFn, resType, bits)))
-    return {};
+  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
+    return FloatAttr::get(resType,
+                          APFloat(resFloatType.getFloatSemantics(), bits));
   return IntegerAttr::get(resType, bits);
 }
 
diff --git a/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp b/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
deleted file mode 100644
index 06c255f01ad69..0000000000000
--- a/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
+++ /dev/null
@@ -1,46 +0,0 @@
-//===- ArithOpsFoldersTest.cpp - unit tests for arith op folders ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Verifier.h"
-#include "gtest/gtest.h"
-
-using namespace mlir;
-
-namespace {
-// Tests a regression that made `BitcastOp::fold` crash on invalid input IR, see
-// #100743,
-TEST(BitcastOpTest, FoldInteger) {
-  MLIRContext context;
-  context.loadDialect<arith::ArithDialect>();
-  auto loc = UnknownLoc::get(&context);
-  auto module = ModuleOp::create(loc);
-  OpBuilder builder(module.getBodyRegion());
-  Value i32Val = builder.create<arith::ConstantOp>(
-      loc, builder.getI32Type(), builder.getI32IntegerAttr(0));
-  // This would create an invalid op: `bitcast` can't cast different bitwidths.
-  builder.createOrFold<arith::BitcastOp>(loc, builder.getI64Type(), i32Val);
-  ASSERT_TRUE(failed(verify(module)));
-}
-
-// Tests a regression that made `BitcastOp::fold` crash on invalid input IR, see
-// #100743,
-TEST(BitcastOpTest, FoldFloat) {
-  MLIRContext context;
-  context.loadDialect<arith::ArithDialect>();
-  auto loc = UnknownLoc::get(&context);
-  auto module = ModuleOp::create(loc);
-  OpBuilder builder(module.getBodyRegion());
-  Value f32Val = builder.create<arith::ConstantOp>(loc, builder.getF32Type(),
-                                                   builder.getF32FloatAttr(0));
-  // This would create an invalid op: `bitcast` can't cast different bitwidths.
-  builder.createOrFold<arith::BitcastOp>(loc, builder.getF64Type(), f32Val);
-  ASSERT_TRUE(failed(verify(module)));
-}
-} // namespace
diff --git a/mlir/unittests/Dialect/Arith/CMakeLists.txt b/mlir/unittests/Dialect/Arith/CMakeLists.txt
deleted file mode 100644
index ac6b701529d3f..0000000000000
--- a/mlir/unittests/Dialect/Arith/CMakeLists.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-add_mlir_unittest(MLIRArithOpsTests
-  ArithOpsFoldersTest.cpp
-)
-target_link_libraries(MLIRArithOpsTests
-  PRIVATE
-  MLIRArithDialect
-)
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index e88e8c61fcb13..90a75d5a46ad9 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -6,7 +6,6 @@ target_link_libraries(MLIRDialectTests
   MLIRIR
   MLIRDialect)
 
-add_subdirectory(Arith)
 add_subdirectory(ArmSME)
 add_subdirectory(Index)
 add_subdirectory(LLVMIR)



More information about the Mlir-commits mailing list