[Mlir-commits] [mlir] 9744d39 - [mlir][index] Implement folders for CastSOp and CastUOp (#66960)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 21 11:53:47 PDT 2023
Author: Jeff Niu
Date: 2023-09-21T11:53:43-07:00
New Revision: 9744d396f8823eaeb15c856f2fa45d8dd4f1d60b
URL: https://github.com/llvm/llvm-project/commit/9744d396f8823eaeb15c856f2fa45d8dd4f1d60b
DIFF: https://github.com/llvm/llvm-project/commit/9744d396f8823eaeb15c856f2fa45d8dd4f1d60b.diff
LOG: [mlir][index] Implement folders for CastSOp and CastUOp (#66960)
Fixes https://github.com/llvm/llvm-project/issues/66402
Added:
mlir/unittests/Dialect/Index/CMakeLists.txt
mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp
Modified:
mlir/include/mlir/Dialect/Index/IR/IndexOps.td
mlir/lib/Dialect/Index/IR/IndexOps.cpp
mlir/test/Dialect/Index/index-canonicalize.mlir
mlir/unittests/Dialect/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index 61cdf4ed0877a0f..c6079cb8a98c813 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -446,7 +446,7 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> {
// CastSOp
//===----------------------------------------------------------------------===//
-def Index_CastSOp : IndexOp<"casts", [Pure,
+def Index_CastSOp : IndexOp<"casts", [Pure,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "index signed cast";
let description = [{
@@ -469,13 +469,14 @@ def Index_CastSOp : IndexOp<"casts", [Pure,
let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// CastUOp
//===----------------------------------------------------------------------===//
-def Index_CastUOp : IndexOp<"castu", [Pure,
+def Index_CastUOp : IndexOp<"castu", [Pure,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "index unsigned cast";
let description = [{
@@ -498,6 +499,7 @@ def Index_CastUOp : IndexOp<"castu", [Pure,
let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index b6d802876c15ede..b506397742772a7 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -444,11 +444,63 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
// CastSOp
//===----------------------------------------------------------------------===//
+static OpFoldResult
+foldCastOp(Attribute input, Type type,
+ function_ref<APInt(const APInt &, unsigned)> extFn,
+ function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
+ auto attr = dyn_cast_if_present<IntegerAttr>(input);
+ if (!attr)
+ return {};
+ const APInt &value = attr.getValue();
+
+ if (isa<IndexType>(type)) {
+ // When casting to an index type, perform the cast assuming a 64-bit target.
+ // The result can be truncated to 32 bits as needed and always be correct.
+ // This is because `cast32(cast64(value)) == cast32(value)`.
+ APInt result = extOrTruncFn(value, 64);
+ return IntegerAttr::get(type, result);
+ }
+
+ // When casting from an index type, we must ensure the results respect
+ // `cast_t(value) == cast_t(trunc32(value))`.
+ auto intType = cast<IntegerType>(type);
+ unsigned width = intType.getWidth();
+
+ // If the result type is at most 32 bits, then the cast can always be folded
+ // because it is always a truncation.
+ if (width <= 32) {
+ APInt result = value.trunc(width);
+ return IntegerAttr::get(type, result);
+ }
+
+ // If the result type is at least 64 bits, then the cast is always an
+ // extension. The results will differ if `trunc32(value) != value`.
+ if (width >= 64) {
+ if (extFn(value.trunc(32), 64) != value)
+ return {};
+ APInt result = extFn(value, width);
+ return IntegerAttr::get(type, result);
+ }
+
+ // Otherwise, we just have to check the property directly.
+ APInt result = value.trunc(width);
+ if (result != extFn(value.trunc(32), width))
+ return {};
+ return IntegerAttr::get(type, result);
+}
+
bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
return llvm::isa<IndexType>(lhsTypes.front()) !=
llvm::isa<IndexType>(rhsTypes.front());
}
+OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
+ return foldCastOp(
+ adaptor.getInput(), getType(),
+ [](const APInt &x, unsigned width) { return x.sext(width); },
+ [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
+}
+
//===----------------------------------------------------------------------===//
// CastUOp
//===----------------------------------------------------------------------===//
@@ -458,6 +510,13 @@ bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
llvm::isa<IndexType>(rhsTypes.front());
}
+OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
+ return foldCastOp(
+ adaptor.getInput(), getType(),
+ [](const APInt &x, unsigned width) { return x.zext(width); },
+ [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
+}
+
//===----------------------------------------------------------------------===//
// CmpOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 67308ffbe55ac6d..db03505350b77e3 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -556,3 +556,19 @@ func.func @sub_identity(%arg0: index) -> index {
// CHECK-NEXT: return %arg0
return %0 : index
}
+
+// CHECK-LABEL: @castu_to_index
+func.func @castu_to_index() -> index {
+ // CHECK: index.constant 8000000000000
+ %0 = arith.constant 8000000000000 : i48
+ %1 = index.castu %0 : i48 to index
+ return %1 : index
+}
+
+// CHECK-LABEL: @casts_to_index
+func.func @casts_to_index() -> index {
+ // CHECK: index.constant -1000
+ %0 = arith.constant -1000 : i48
+ %1 = index.casts %0 : i48 to index
+ return %1 : index
+}
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 522aeca29146d17..2d2835c64b9844f 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests
MLIRIR
MLIRDialect)
+add_subdirectory(Index)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
add_subdirectory(SparseTensor)
diff --git a/mlir/unittests/Dialect/Index/CMakeLists.txt b/mlir/unittests/Dialect/Index/CMakeLists.txt
new file mode 100644
index 000000000000000..c4bac2371e52fb3
--- /dev/null
+++ b/mlir/unittests/Dialect/Index/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRIndexOpsTests
+ IndexOpsFoldersTest.cpp
+)
+target_link_libraries(MLIRIndexOpsTests
+ PRIVATE
+ MLIRIndexDialect
+)
diff --git a/mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp b/mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp
new file mode 100644
index 000000000000000..948033ddb5934a6
--- /dev/null
+++ b/mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp
@@ -0,0 +1,104 @@
+//===- IndexOpsFoldersTest.cpp - unit tests for index op folders ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+/// Test fixture for testing operation folders of the index dialect.
+class IndexFolderTest : public testing::Test {
+public:
+ IndexFolderTest() { ctx.getOrLoadDialect<index::IndexDialect>(); }
+
+ /// Instantiate an operation, invoke its folder, and store the folded integer
+ /// attribute in `value`; `value` is set to null if the fold fails.
+ template <typename OpT>
+ void foldOp(IntegerAttr &value, Type type, ArrayRef<Attribute> operands);
+
+protected:
+ /// The MLIR context to use; the index dialect is loaded into it on construction.
+ MLIRContext ctx;
+ /// A builder constructed on `ctx`.
+ OpBuilder b{&ctx};
+};
+} // namespace
+
+template <typename OpT>
+void IndexFolderTest::foldOp(IntegerAttr &value, Type type,
+ ArrayRef<Attribute> operands) {
+ // This function returns void so that `ASSERT_*` works within it; the result is passed back through `value`.
+ OperationState state(UnknownLoc::get(&ctx), OpT::getOperationName());
+ state.addTypes(type);
+ OwningOpRef<OpT> op = cast<OpT>(b.create(state));
+ SmallVector<OpFoldResult> results;
+ LogicalResult result = op->getOperation()->fold(operands, results);
+ // Propagate the failure to the test by clearing `value`.
+ if (failed(result)) {
+ value = nullptr;
+ return;
+ }
+ ASSERT_EQ(results.size(), 1u);
+ value = dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(results.front()));
+ ASSERT_TRUE(value);
+}
+
+TEST_F(IndexFolderTest, TestCastUOpFolder) {
+ IntegerAttr value;
+ auto fold = [&](Type type, Attribute input) {
+ foldOp<index::CastUOp>(value, type, input);
+ };
+
+ // Target width less than or equal to 32 bits.
+ fold(b.getIntegerType(16), b.getIndexAttr(8000000000));
+ ASSERT_TRUE(value);
+ EXPECT_EQ(value.getInt(), 20480u);
+
+ // Target width greater than or equal to 64 bits.
+ fold(b.getIntegerType(64), b.getIndexAttr(2000));
+ ASSERT_TRUE(value);
+ EXPECT_EQ(value.getInt(), 2000u);
+
+ // Fails to fold, because truncating to 32 bits and then extending creates a
+ // different value.
+ fold(b.getIntegerType(64), b.getIndexAttr(8000000000));
+ EXPECT_FALSE(value);
+
+ // Target width between 32 and 64 bits.
+ fold(b.getIntegerType(40), b.getIndexAttr(0x10000000010000));
+ // Fold succeeds because the upper bits are truncated in the cast.
+ ASSERT_TRUE(value);
+ EXPECT_EQ(value.getInt(), 65536);
+
+ // Fails to fold because the upper bits are not truncated.
+ fold(b.getIntegerType(60), b.getIndexAttr(0x10000000010000));
+ EXPECT_FALSE(value);
+}
+
+TEST_F(IndexFolderTest, TestCastSOpFolder) {
+ IntegerAttr value;
+ auto fold = [&](Type type, Attribute input) {
+ foldOp<index::CastSOp>(value, type, input);
+ };
+
+ // Just test the extension cases to ensure signs are being respected.
+
+ // Target width greater than or equal to 64 bits.
+ fold(b.getIntegerType(64), b.getIndexAttr(-2000));
+ ASSERT_TRUE(value);
+ EXPECT_EQ(value.getInt(), -2000);
+
+ // Target width between 32 and 64 bits.
+ fold(b.getIntegerType(40), b.getIndexAttr(-0x10000000010000));
+ // Fold succeeds because the upper bits are truncated in the cast.
+ ASSERT_TRUE(value);
+ EXPECT_EQ(value.getInt(), -65536);
+}
More information about the Mlir-commits
mailing list