[Mlir-commits] [mlir] [mlir][index] Implement folders for CastSOp and CastUOp (PR #66960)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 20 16:38:34 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-index
<details>
<summary>Changes</summary>
Fixes https://github.com/llvm/llvm-project/issues/66402
---
Full diff: https://github.com/llvm/llvm-project/pull/66960.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Index/IR/IndexOps.td (+4-2)
- (modified) mlir/lib/Dialect/Index/IR/IndexOps.cpp (+59)
- (modified) mlir/test/Dialect/Index/index-canonicalize.mlir (+8)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index 61cdf4ed0877a0f..c6079cb8a98c813 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -446,7 +446,7 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> {
// CastSOp
//===----------------------------------------------------------------------===//
-def Index_CastSOp : IndexOp<"casts", [Pure,
+def Index_CastSOp : IndexOp<"casts", [Pure,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "index signed cast";
let description = [{
@@ -469,13 +469,14 @@ def Index_CastSOp : IndexOp<"casts", [Pure,
let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// CastUOp
//===----------------------------------------------------------------------===//
-def Index_CastUOp : IndexOp<"castu", [Pure,
+def Index_CastUOp : IndexOp<"castu", [Pure,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "index unsigned cast";
let description = [{
@@ -498,6 +499,7 @@ def Index_CastUOp : IndexOp<"castu", [Pure,
let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index b6d802876c15ede..b506397742772a7 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -444,11 +444,63 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
// CastSOp
//===----------------------------------------------------------------------===//
+static OpFoldResult // Constant folder shared by CastSOp::fold and CastUOp::fold.
+foldCastOp(Attribute input, Type type, // `input`: operand constant (may be null); `type`: result type
+ function_ref<APInt(const APInt &, unsigned)> extFn, // sext for casts, zext for castu
+ function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) { // sextOrTrunc / zextOrTrunc
+ auto attr = dyn_cast_if_present<IntegerAttr>(input); // non-constant input: nothing to fold
+ if (!attr)
+ return {};
+ const APInt &value = attr.getValue();
+
+ if (isa<IndexType>(type)) {
+ // When casting to an index type, perform the cast assuming a 64-bit target.
+ // On a 32-bit target the result is simply truncated and is still correct,
+ // because `cast32(cast64(value)) == cast32(value)`.
+ APInt result = extOrTruncFn(value, 64);
+ return IntegerAttr::get(type, result);
+ }
+
+ // When casting from an index type, the fold is only valid if it gives the same
+ // result on 32- and 64-bit targets: `cast_t(value) == cast_t(trunc32(value))`.
+ auto intType = cast<IntegerType>(type);
+ unsigned width = intType.getWidth();
+
+ // If the result type is at most 32 bits, then the cast can always be folded:
+ // it is a truncation on both targets and keeps the same low bits.
+ if (width <= 32) {
+ APInt result = value.trunc(width);
+ return IntegerAttr::get(type, result);
+ }
+
+ // If the result type is at least 64 bits, then the cast is always an
+ // extension. The results differ iff `ext64(trunc32(value)) != value`.
+ if (width >= 64) {
+ if (extFn(value.trunc(32), 64) != value) // result would depend on the target width
+ return {};
+ APInt result = extFn(value, width);
+ return IntegerAttr::get(type, result);
+ }
+
+ // Otherwise (32 < width < 64): compare both targets' results directly.
+ APInt result = value.trunc(width); // 64-bit target: truncation
+ if (result != extFn(value.trunc(32), width)) // 32-bit target: extension
+ return {};
+ return IntegerAttr::get(type, result);
+}
+
bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { // exactly one side must be index
return llvm::isa<IndexType>(lhsTypes.front()) !=
llvm::isa<IndexType>(rhsTypes.front());
}
+OpFoldResult CastSOp::fold(FoldAdaptor adaptor) { // signed cast: fold with sign extension
+ return foldCastOp(
+ adaptor.getInput(), getType(),
+ /*extFn=*/[](const APInt &x, unsigned width) { return x.sext(width); },
+ /*extOrTruncFn=*/[](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
+}
+
//===----------------------------------------------------------------------===//
// CastUOp
//===----------------------------------------------------------------------===//
@@ -458,6 +510,13 @@ bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
llvm::isa<IndexType>(rhsTypes.front());
}
+OpFoldResult CastUOp::fold(FoldAdaptor adaptor) { // unsigned cast: fold with zero extension
+ return foldCastOp(
+ adaptor.getInput(), getType(),
+ /*extFn=*/[](const APInt &x, unsigned width) { return x.zext(width); },
+ /*extOrTruncFn=*/[](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
+}
+
//===----------------------------------------------------------------------===//
// CmpOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 67308ffbe55ac6d..f3eae3605b3b64f 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -556,3 +556,29 @@ func.func @sub_identity(%arg0: index) -> index {
// CHECK-NEXT: return %arg0
return %0 : index
}
+
+// CHECK-LABEL: @castu_to_index
+func.func @castu_to_index() -> index {
+ // CHECK: index.constant 8000000000000
+ %0 = arith.constant 8000000000000 : i48
+ %1 = index.castu %0 : i48 to index
+ return %1 : index
+}
+
+// Signed counterpart: a negative i48 constant is sign-extended to 64 bits.
+// CHECK-LABEL: @casts_to_index
+func.func @casts_to_index() -> index {
+ // CHECK: index.constant -8000000000000
+ %0 = arith.constant -8000000000000 : i48
+ %1 = index.casts %0 : i48 to index
+ return %1 : index
+}
+
+// `-1 : i8` is 0xFF, which castu zero-extends to 255 (casts would give -1).
+// CHECK-LABEL: @castu_i8_to_index
+func.func @castu_i8_to_index() -> index {
+ // CHECK: index.constant 255
+ %0 = arith.constant -1 : i8
+ %1 = index.castu %0 : i8 to index
+ return %1 : index
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/66960
More information about the Mlir-commits mailing list