[Mlir-commits] [mlir] [mlir][arith] Check for valid IR in BitcastOp::fold. (PR #100743)
Ingo Müller
llvmlistbot at llvm.org
Sun Jul 28 23:37:13 PDT 2024
https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/100743
>From a8b01cd5fa21fb85db91d02a57a9191a88a18823 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 26 Jul 2024 13:30:21 +0000
Subject: [PATCH 1/2] [mlir][arith] Check for valid IR in BitcastOp::fold.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This PR prevents `BitcastOp::fold` from creating invalid `IntegerAttr` and
`FloatAttr` values, which result in failed assertions. This can happen
if the input IR is invalid. The PR adds checks for whether the
to-be-created attribute verifies and returns early from `fold` if it
doesn't.
Signed-off-by: Ingo Müller <ingomueller at google.com>
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 21 +++++++--
.../Dialect/Arith/ArithOpsFoldersTest.cpp | 46 +++++++++++++++++++
mlir/unittests/Dialect/Arith/CMakeLists.txt | 7 +++
mlir/unittests/Dialect/CMakeLists.txt | 1 +
4 files changed, 72 insertions(+), 3 deletions(-)
create mode 100644 mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
create mode 100644 mlir/unittests/Dialect/Arith/CMakeLists.txt
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index aa5eb95a3d22e..55e8fcde0c498 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1740,9 +1740,24 @@ OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
: llvm::cast<IntegerAttr>(operand).getValue();
- if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
- return FloatAttr::get(resType,
- APFloat(resFloatType.getFloatSemantics(), bits));
+ /// If bitwidth aren't the same, don't fold.
+ if (resType.getIntOrFloatBitWidth() != bits.getBitWidth())
+ return {};
+
+ MLIRContext *ctx = getContext();
+ auto emitErrorFn = [=] { return ::emitError(UnknownLoc::get(ctx)); };
+
+ if (auto resFloatType = llvm::dyn_cast<FloatType>(resType)) {
+ /// If bits don't represent a valid float, don't fold.
+ APFloat floatBits(resFloatType.getFloatSemantics(), bits);
+ if (failed(FloatAttr::verify(emitErrorFn, resType, floatBits)))
+ return {};
+ return FloatAttr::get(resType, floatBits);
+ }
+
+ /// If bits don't represent a valid integer, don't fold.
+ if (failed(IntegerAttr::verify(emitErrorFn, resType, bits)))
+ return {};
return IntegerAttr::get(resType, bits);
}
diff --git a/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp b/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
new file mode 100644
index 0000000000000..06c255f01ad69
--- /dev/null
+++ b/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
@@ -0,0 +1,46 @@
+//===- ArithOpsFoldersTest.cpp - unit tests for arith op folders ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Verifier.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+// Tests a regression that made `BitcastOp::fold` crash on invalid input IR, see
+// #100743,
+TEST(BitcastOpTest, FoldInteger) {
+ MLIRContext context;
+ context.loadDialect<arith::ArithDialect>();
+ auto loc = UnknownLoc::get(&context);
+ auto module = ModuleOp::create(loc);
+ OpBuilder builder(module.getBodyRegion());
+ Value i32Val = builder.create<arith::ConstantOp>(
+ loc, builder.getI32Type(), builder.getI32IntegerAttr(0));
+ // This would create an invalid op: `bitcast` can't cast different bitwidths.
+ builder.createOrFold<arith::BitcastOp>(loc, builder.getI64Type(), i32Val);
+ ASSERT_TRUE(failed(verify(module)));
+}
+
+// Tests a regression that made `BitcastOp::fold` crash on invalid input IR, see
+// #100743,
+TEST(BitcastOpTest, FoldFloat) {
+ MLIRContext context;
+ context.loadDialect<arith::ArithDialect>();
+ auto loc = UnknownLoc::get(&context);
+ auto module = ModuleOp::create(loc);
+ OpBuilder builder(module.getBodyRegion());
+ Value f32Val = builder.create<arith::ConstantOp>(loc, builder.getF32Type(),
+ builder.getF32FloatAttr(0));
+ // This would create an invalid op: `bitcast` can't cast different bitwidths.
+ builder.createOrFold<arith::BitcastOp>(loc, builder.getF64Type(), f32Val);
+ ASSERT_TRUE(failed(verify(module)));
+}
+} // namespace
diff --git a/mlir/unittests/Dialect/Arith/CMakeLists.txt b/mlir/unittests/Dialect/Arith/CMakeLists.txt
new file mode 100644
index 0000000000000..ac6b701529d3f
--- /dev/null
+++ b/mlir/unittests/Dialect/Arith/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRArithOpsTests
+ ArithOpsFoldersTest.cpp
+)
+target_link_libraries(MLIRArithOpsTests
+ PRIVATE
+ MLIRArithDialect
+)
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 90a75d5a46ad9..e88e8c61fcb13 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests
MLIRIR
MLIRDialect)
+add_subdirectory(Arith)
add_subdirectory(ArmSME)
add_subdirectory(Index)
add_subdirectory(LLVMIR)
>From 293c6f02049df809f14b9f51fdd2232f107f91b9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Mon, 29 Jul 2024 06:36:49 +0000
Subject: [PATCH 2/2] Remove unittest. Replace early exits with assert.
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 23 ++--------
.../Dialect/Arith/ArithOpsFoldersTest.cpp | 46 -------------------
mlir/unittests/Dialect/Arith/CMakeLists.txt | 7 ---
mlir/unittests/Dialect/CMakeLists.txt | 1 -
4 files changed, 5 insertions(+), 72 deletions(-)
delete mode 100644 mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
delete mode 100644 mlir/unittests/Dialect/Arith/CMakeLists.txt
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 55e8fcde0c498..641b7d7e2d13b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1739,25 +1739,12 @@ OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
APInt bits = llvm::isa<FloatAttr>(operand)
? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
: llvm::cast<IntegerAttr>(operand).getValue();
+ assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
+ "trying to fold on broken IR: operands have incompatible types");
- /// If bitwidth aren't the same, don't fold.
- if (resType.getIntOrFloatBitWidth() != bits.getBitWidth())
- return {};
-
- MLIRContext *ctx = getContext();
- auto emitErrorFn = [=] { return ::emitError(UnknownLoc::get(ctx)); };
-
- if (auto resFloatType = llvm::dyn_cast<FloatType>(resType)) {
- /// If bits don't represent a valid float, don't fold.
- APFloat floatBits(resFloatType.getFloatSemantics(), bits);
- if (failed(FloatAttr::verify(emitErrorFn, resType, floatBits)))
- return {};
- return FloatAttr::get(resType, floatBits);
- }
-
- /// If bits don't represent a valid integer, don't fold.
- if (failed(IntegerAttr::verify(emitErrorFn, resType, bits)))
- return {};
+ if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
+ return FloatAttr::get(resType,
+ APFloat(resFloatType.getFloatSemantics(), bits));
return IntegerAttr::get(resType, bits);
}
diff --git a/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp b/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
deleted file mode 100644
index 06c255f01ad69..0000000000000
--- a/mlir/unittests/Dialect/Arith/ArithOpsFoldersTest.cpp
+++ /dev/null
@@ -1,46 +0,0 @@
-//===- ArithOpsFoldersTest.cpp - unit tests for arith op folders ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Verifier.h"
-#include "gtest/gtest.h"
-
-using namespace mlir;
-
-namespace {
-// Tests a regression that made `BitcastOp::fold` crash on invalid input IR, see
-// #100743,
-TEST(BitcastOpTest, FoldInteger) {
- MLIRContext context;
- context.loadDialect<arith::ArithDialect>();
- auto loc = UnknownLoc::get(&context);
- auto module = ModuleOp::create(loc);
- OpBuilder builder(module.getBodyRegion());
- Value i32Val = builder.create<arith::ConstantOp>(
- loc, builder.getI32Type(), builder.getI32IntegerAttr(0));
- // This would create an invalid op: `bitcast` can't cast different bitwidths.
- builder.createOrFold<arith::BitcastOp>(loc, builder.getI64Type(), i32Val);
- ASSERT_TRUE(failed(verify(module)));
-}
-
-// Tests a regression that made `BitcastOp::fold` crash on invalid input IR, see
-// #100743,
-TEST(BitcastOpTest, FoldFloat) {
- MLIRContext context;
- context.loadDialect<arith::ArithDialect>();
- auto loc = UnknownLoc::get(&context);
- auto module = ModuleOp::create(loc);
- OpBuilder builder(module.getBodyRegion());
- Value f32Val = builder.create<arith::ConstantOp>(loc, builder.getF32Type(),
- builder.getF32FloatAttr(0));
- // This would create an invalid op: `bitcast` can't cast different bitwidths.
- builder.createOrFold<arith::BitcastOp>(loc, builder.getF64Type(), f32Val);
- ASSERT_TRUE(failed(verify(module)));
-}
-} // namespace
diff --git a/mlir/unittests/Dialect/Arith/CMakeLists.txt b/mlir/unittests/Dialect/Arith/CMakeLists.txt
deleted file mode 100644
index ac6b701529d3f..0000000000000
--- a/mlir/unittests/Dialect/Arith/CMakeLists.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-add_mlir_unittest(MLIRArithOpsTests
- ArithOpsFoldersTest.cpp
-)
-target_link_libraries(MLIRArithOpsTests
- PRIVATE
- MLIRArithDialect
-)
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index e88e8c61fcb13..90a75d5a46ad9 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -6,7 +6,6 @@ target_link_libraries(MLIRDialectTests
MLIRIR
MLIRDialect)
-add_subdirectory(Arith)
add_subdirectory(ArmSME)
add_subdirectory(Index)
add_subdirectory(LLVMIR)
More information about the Mlir-commits
mailing list