[Mlir-commits] [mlir] 9744d39 - [mlir][index] Implement folders for CastSOp and CastUOp (#66960)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 21 11:53:47 PDT 2023


Author: Jeff Niu
Date: 2023-09-21T11:53:43-07:00
New Revision: 9744d396f8823eaeb15c856f2fa45d8dd4f1d60b

URL: https://github.com/llvm/llvm-project/commit/9744d396f8823eaeb15c856f2fa45d8dd4f1d60b
DIFF: https://github.com/llvm/llvm-project/commit/9744d396f8823eaeb15c856f2fa45d8dd4f1d60b.diff

LOG: [mlir][index] Implement folders for CastSOp and CastUOp (#66960)

Fixes https://github.com/llvm/llvm-project/issues/66402

Added: 
    mlir/unittests/Dialect/Index/CMakeLists.txt
    mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp

Modified: 
    mlir/include/mlir/Dialect/Index/IR/IndexOps.td
    mlir/lib/Dialect/Index/IR/IndexOps.cpp
    mlir/test/Dialect/Index/index-canonicalize.mlir
    mlir/unittests/Dialect/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index 61cdf4ed0877a0f..c6079cb8a98c813 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -446,7 +446,7 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> {
 // CastSOp
 //===----------------------------------------------------------------------===//
 
-def Index_CastSOp : IndexOp<"casts", [Pure, 
+def Index_CastSOp : IndexOp<"casts", [Pure,
     DeclareOpInterfaceMethods<CastOpInterface>]> {
   let summary = "index signed cast";
   let description = [{
@@ -469,13 +469,14 @@ def Index_CastSOp : IndexOp<"casts", [Pure,
   let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
   let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
 // CastUOp
 //===----------------------------------------------------------------------===//
 
-def Index_CastUOp : IndexOp<"castu", [Pure, 
+def Index_CastUOp : IndexOp<"castu", [Pure,
     DeclareOpInterfaceMethods<CastOpInterface>]> {
   let summary = "index unsigned cast";
   let description = [{
@@ -498,6 +499,7 @@ def Index_CastUOp : IndexOp<"castu", [Pure,
   let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
   let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index b6d802876c15ede..b506397742772a7 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -444,11 +444,63 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
 // CastSOp
 //===----------------------------------------------------------------------===//
 
+/// Constant-fold an index cast of `input` to `type`. Returns a null result if
+/// `input` is not an integer constant or if the folded value would depend on
+/// whether the index type is 32 or 64 bits wide. `extFn`/`extOrTruncFn` select
+/// sign- vs. zero-extension, i.e. `index.casts` vs. `index.castu`.
+static OpFoldResult
+foldCastOp(Attribute input, Type type,
+           function_ref<APInt(const APInt &, unsigned)> extFn,
+           function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
+  auto attr = dyn_cast_if_present<IntegerAttr>(input);
+  if (!attr)
+    return {};
+  const APInt &value = attr.getValue();
+
+  if (isa<IndexType>(type)) {
+    // When casting to an index type, perform the cast assuming a 64-bit target.
+    // The result can be truncated to 32 bits as needed and always be correct.
+    // This is because `cast32(cast64(value)) == cast32(value)`.
+    return IntegerAttr::get(type, extOrTruncFn(value, 64));
+  }
+
+  // When casting from an index type, we must ensure the results respect
+  // `cast_t(value) == cast_t(trunc32(value))`.
+  unsigned width = cast<IntegerType>(type).getWidth();
+
+  // If the result type is at most 32 bits, then the cast can always be folded
+  // because it is always a truncation.
+  if (width <= 32)
+    return IntegerAttr::get(type, value.trunc(width));
+
+  // If the result type is at least 64 bits, then the cast is always an
+  // extension. The results differ unless `ext(trunc32(value)) == value`.
+  if (width >= 64) {
+    if (extFn(value.trunc(32), 64) != value)
+      return {};
+    return IntegerAttr::get(type, extFn(value, width));
+  }
+
+  // Otherwise (32 < width < 64) the cast truncates a 64-bit index but extends
+  // a 32-bit one, so check the property directly.
+  APInt result = value.trunc(width);
+  if (result != extFn(value.trunc(32), width))
+    return {};
+  return IntegerAttr::get(type, result);
+}
+
 bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
   return llvm::isa<IndexType>(lhsTypes.front()) !=
          llvm::isa<IndexType>(rhsTypes.front());
 }
 
+OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
+  // Signed cast: fold with sign-extension semantics.
+  auto extend = [](const APInt &x, unsigned w) { return x.sext(w); };
+  auto resize = [](const APInt &x, unsigned w) { return x.sextOrTrunc(w); };
+  return foldCastOp(adaptor.getInput(), getType(), extend, resize);
+}
+
 //===----------------------------------------------------------------------===//
 // CastUOp
 //===----------------------------------------------------------------------===//
@@ -458,6 +510,13 @@ bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
          llvm::isa<IndexType>(rhsTypes.front());
 }
 
+OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
+  // Unsigned cast: fold with zero-extension semantics.
+  auto extend = [](const APInt &x, unsigned w) { return x.zext(w); };
+  auto resize = [](const APInt &x, unsigned w) { return x.zextOrTrunc(w); };
+  return foldCastOp(adaptor.getInput(), getType(), extend, resize);
+}
+
 //===----------------------------------------------------------------------===//
 // CmpOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 67308ffbe55ac6d..db03505350b77e3 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -556,3 +556,19 @@ func.func @sub_identity(%arg0: index) -> index {
   // CHECK-NEXT: return %arg0
   return %0 : index
 }
+
+// CHECK-LABEL: @castu_to_index
+func.func @castu_to_index() -> index {
+  // CHECK: index.constant 8000000000000
+  %c = arith.constant 8000000000000 : i48
+  %idx = index.castu %c : i48 to index
+  return %idx : index
+}
+
+// CHECK-LABEL: @casts_to_index
+func.func @casts_to_index() -> index {
+  // CHECK: index.constant -1000
+  %c = arith.constant -1000 : i48
+  %idx = index.casts %c : i48 to index
+  return %idx : index
+}

diff  --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 522aeca29146d17..2d2835c64b9844f 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests
   MLIRIR
   MLIRDialect)
 
+add_subdirectory(Index)
 add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)
 add_subdirectory(SparseTensor)

diff  --git a/mlir/unittests/Dialect/Index/CMakeLists.txt b/mlir/unittests/Dialect/Index/CMakeLists.txt
new file mode 100644
index 000000000000000..c4bac2371e52fb3
--- /dev/null
+++ b/mlir/unittests/Dialect/Index/CMakeLists.txt
@@ -0,0 +1,7 @@
+# Unit tests for the Index dialect's constant folders.
+add_mlir_unittest(MLIRIndexOpsTests IndexOpsFoldersTest.cpp)
+
+target_link_libraries(MLIRIndexOpsTests
+  PRIVATE
+    MLIRIndexDialect
+  )

diff  --git a/mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp b/mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp
new file mode 100644
index 000000000000000..948033ddb5934a6
--- /dev/null
+++ b/mlir/unittests/Dialect/Index/IndexOpsFoldersTest.cpp
@@ -0,0 +1,104 @@
+//===- IndexOpsFoldersTest.cpp - unit tests for index op folders ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+/// Test fixture for the Index dialect's operation folders.
+class IndexFolderTest : public testing::Test {
+public:
+  IndexFolderTest() { ctx.getOrLoadDialect<index::IndexDialect>(); }
+
+  /// Create an `OpT` of result type `type`, fold it using `operands` as the
+  /// constant operands, and store the result in `value` (null if no fold).
+  template <typename OpT>
+  void foldOp(IntegerAttr &value, Type type, ArrayRef<Attribute> operands);
+
+protected:
+  /// The MLIR context; declared before `b` so it is constructed first.
+  MLIRContext ctx;
+  /// Builder used to create the ops under test.
+  OpBuilder b{&ctx};
+};
+} // namespace
+
+template <typename OpT>
+void IndexFolderTest::foldOp(IntegerAttr &value, Type type,
+                             ArrayRef<Attribute> operands) {
+  // Returns void so `ASSERT_*` can be used; clear `value` up front so a fatal
+  // assertion below never leaves a stale result from a previous call.
+  value = nullptr;
+  OperationState state(UnknownLoc::get(&ctx), OpT::getOperationName());
+  state.addTypes(type);
+  OwningOpRef<OpT> op = cast<OpT>(b.create(state));
+  SmallVector<OpFoldResult> results;
+  // A failed fold is reported to the caller as a null `value`.
+  if (failed(op->getOperation()->fold(operands, results)))
+    return;
+  ASSERT_EQ(results.size(), 1u);
+  value =
+      dyn_cast_if_present<IntegerAttr>(dyn_cast<Attribute>(results.front()));
+  ASSERT_TRUE(value);
+}
+
+TEST_F(IndexFolderTest, TestCastUOpFolder) {
+  IntegerAttr value;
+  auto fold = [&](Type type, Attribute input) {
+    foldOp<index::CastUOp>(value, type, input);
+  };
+
+  // Target width less than or equal to 32 bits.
+  fold(b.getIntegerType(16), b.getIndexAttr(8000000000));
+  ASSERT_TRUE(value);
+  EXPECT_EQ(value.getInt(), 20480);
+
+  // Target width greater than or equal to 64 bits.
+  fold(b.getIntegerType(64), b.getIndexAttr(2000));
+  ASSERT_TRUE(value);
+  EXPECT_EQ(value.getInt(), 2000);
+
+  // Fails to fold, because truncating to 32 bits and then extending creates a
+  // different value.
+  fold(b.getIntegerType(64), b.getIndexAttr(8000000000));
+  EXPECT_FALSE(value);
+
+  // Target width between 32 and 64 bits.
+  fold(b.getIntegerType(40), b.getIndexAttr(0x10000000010000));
+  // Fold succeeds because the upper bits are truncated in the cast.
+  ASSERT_TRUE(value);
+  EXPECT_EQ(value.getInt(), 65536);
+
+  // Fails to fold: bit 52 survives truncation to i60 but not to 32 bits.
+  fold(b.getIntegerType(60), b.getIndexAttr(0x10000000010000));
+  EXPECT_FALSE(value);
+}
+
+TEST_F(IndexFolderTest, TestCastSOpFolder) {
+  // Only the sign-extending cases are exercised here: truncating casts behave
+  // identically for `casts` and `castu`, so they are covered by the other test.
+  IntegerAttr folded;
+  auto foldCastS = [&](unsigned width, Attribute input) {
+    foldOp<index::CastSOp>(folded, b.getIntegerType(width), input);
+  };
+
+  // Result of at least 64 bits: the sign of -2000 must survive the extension.
+  foldCastS(64, b.getIndexAttr(-2000));
+  ASSERT_TRUE(folded);
+  EXPECT_EQ(folded.getInt(), -2000);
+
+  // Result between 32 and 64 bits: the high bits are dropped by the cast, and
+  // the low 40 bits equal the sign-extension of the low 32 bits.
+  foldCastS(40, b.getIndexAttr(-0x10000000010000));
+  ASSERT_TRUE(folded);
+  EXPECT_EQ(folded.getInt(), -65536);
+}


        


More information about the Mlir-commits mailing list