[Mlir-commits] [mlir] [mlir][index] Implement folders for CastSOp and CastUOp (PR #66960)

Jeff Niu llvmlistbot at llvm.org
Wed Sep 20 16:37:25 PDT 2023


https://github.com/Mogball created https://github.com/llvm/llvm-project/pull/66960

Fixes https://github.com/llvm/llvm-project/issues/66402

From b8f5804651e4e3861d004cb50d6205eaf36de0c0 Mon Sep 17 00:00:00 2001
From: Mogball <jeff at modular.com>
Date: Wed, 20 Sep 2023 16:30:40 -0700
Subject: [PATCH] [mlir][index] Implement folders for CastSOp and CastUOp

---
 .../include/mlir/Dialect/Index/IR/IndexOps.td |  6 +-
 mlir/lib/Dialect/Index/IR/IndexOps.cpp        | 59 +++++++++++++++++++
 .../Dialect/Index/index-canonicalize.mlir     |  8 +++
 3 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index 61cdf4ed0877a0f..c6079cb8a98c813 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -446,7 +446,7 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> {
 // CastSOp
 //===----------------------------------------------------------------------===//
 
-def Index_CastSOp : IndexOp<"casts", [Pure, 
+def Index_CastSOp : IndexOp<"casts", [Pure,
     DeclareOpInterfaceMethods<CastOpInterface>]> {
   let summary = "index signed cast";
   let description = [{
@@ -469,13 +469,14 @@ def Index_CastSOp : IndexOp<"casts", [Pure,
   let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
   let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
 // CastUOp
 //===----------------------------------------------------------------------===//
 
-def Index_CastUOp : IndexOp<"castu", [Pure, 
+def Index_CastUOp : IndexOp<"castu", [Pure,
     DeclareOpInterfaceMethods<CastOpInterface>]> {
   let summary = "index unsigned cast";
   let description = [{
@@ -498,6 +499,7 @@ def Index_CastUOp : IndexOp<"castu", [Pure,
   let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
   let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index b6d802876c15ede..b506397742772a7 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -444,11 +444,63 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
 // CastSOp
 //===----------------------------------------------------------------------===//
 
+/// Folds `index.casts` / `index.castu` when the operand is an `IntegerAttr`.
+/// `extFn` widens an APInt and `extOrTruncFn` widens or narrows it; callers
+/// pass the signed or unsigned variants. A null result means "do not fold":
+/// either the operand is not constant or the folded value would depend on
+/// whether `index` is 32 or 64 bits wide.
+static OpFoldResult
+foldCastOp(Attribute input, Type type,
+           function_ref<APInt(const APInt &, unsigned)> extFn,
+           function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
+  auto intAttr = dyn_cast_if_present<IntegerAttr>(input);
+  if (!intAttr)
+    return {};
+  const APInt &bits = intAttr.getValue();
+
+  // Integer -> index: fold as if `index` were 64 bits wide. A 32-bit target
+  // truncates that result, which is still correct because
+  // `trunc32(cast64(bits)) == cast32(bits)`.
+  if (isa<IndexType>(type))
+    return IntegerAttr::get(type, extOrTruncFn(bits, 64));
+
+  // Index -> integer: `bits` is the 64-bit index value, while a 32-bit target
+  // would observe `trunc32(bits)`. Only fold when both targets agree, i.e.
+  // when `cast_t(bits) == cast_t(trunc32(bits))`.
+  unsigned dstWidth = cast<IntegerType>(type).getWidth();
+
+  // Results of 64 bits or more are an extension on either target; they agree
+  // exactly when re-extending the low 32 bits reproduces the 64-bit value,
+  // i.e. when `ext64(trunc32(bits)) == bits`.
+  if (dstWidth >= 64) {
+    if (extFn(bits.trunc(32), 64) != bits)
+      return {};
+    return IntegerAttr::get(type, extFn(bits, dstWidth));
+  }
+
+  // Narrower results truncate the 64-bit value. Up to 32 bits a 32-bit target
+  // truncates as well, so the results always agree; between 33 and 63 bits
+  // that target extends instead, so the two must be compared.
+  APInt folded = bits.trunc(dstWidth);
+  bool targetDependent =
+      dstWidth > 32 && folded != extFn(bits.trunc(32), dstWidth);
+  if (targetDependent)
+    return {};
+  return IntegerAttr::get(type, folded);
+}
+
 bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
   return llvm::isa<IndexType>(lhsTypes.front()) !=
          llvm::isa<IndexType>(rhsTypes.front());
 }
 
+OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
+  // Signed cast: sign-extend to widen, sign-extend or truncate to resize.
+  auto widen = [](const APInt &v, unsigned w) { return v.sext(w); };
+  auto resize = [](const APInt &v, unsigned w) { return v.sextOrTrunc(w); };
+  return foldCastOp(adaptor.getInput(), getType(), widen, resize);
+}
+
 //===----------------------------------------------------------------------===//
 // CastUOp
 //===----------------------------------------------------------------------===//
@@ -458,6 +510,13 @@ bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
          llvm::isa<IndexType>(rhsTypes.front());
 }
 
+OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
+  // Unsigned cast: zero-extend to widen, zero-extend or truncate to resize.
+  auto widen = [](const APInt &v, unsigned w) { return v.zext(w); };
+  auto resize = [](const APInt &v, unsigned w) { return v.zextOrTrunc(w); };
+  return foldCastOp(adaptor.getInput(), getType(), widen, resize);
+}
+
 //===----------------------------------------------------------------------===//
 // CmpOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 67308ffbe55ac6d..f3eae3605b3b64f 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -556,3 +556,11 @@ func.func @sub_identity(%arg0: index) -> index {
   // CHECK-NEXT: return %arg0
   return %0 : index
 }
+
+// CHECK-LABEL: @castu_to_index
+func.func @castu_to_index() -> index {
+  // CHECK: index.constant 8000000000000
+  %wide = arith.constant 8000000000000 : i48 // Needs 43 bits: wider than i32.
+  %idx = index.castu %wide : i48 to index
+  return %idx : index
+}



More information about the Mlir-commits mailing list