[Mlir-commits] [mlir] [mlir][arith] Support bitcast with index type (PR #121455)

Jianjian Guan llvmlistbot at llvm.org
Wed Jan 1 23:31:42 PST 2025


https://github.com/jacquesguan created https://github.com/llvm/llvm-project/pull/121455

Use IndexType::kInternalStorageBitWidth (64) as the bit width of the index type, so that arith.bitcast can convert between index and other 64-bit types.
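The width rule the patch adds to arith::BitcastOp::areCastCompatible can be sketched without MLIR as a small standalone model. This is a minimal sketch under stated assumptions, not MLIR's implementation. ScalarKind, ScalarType, kIndexStorageBitWidth, bitWidthForBitcast and areBitcastCompatible are hypothetical names for illustration. The only assumption carried over from MLIR is that IndexType::kInternalStorageBitWidth is 64. Index has no intrinsic width, so the fixed storage width stands in for it, and the cast is allowed only when both widths match.

```cpp
#include <cassert>

// Hypothetical standalone model of the patched width check; not an MLIR API.
enum class ScalarKind { Integer, Float, Index };

struct ScalarType {
  ScalarKind kind;
  unsigned width; // Ignored when kind == ScalarKind::Index.
};

// Stands in for IndexType::kInternalStorageBitWidth, which is 64 in MLIR.
constexpr unsigned kIndexStorageBitWidth = 64;

// Index has no intrinsic width, so use the fixed internal storage width;
// integer and float types report their own width.
constexpr unsigned bitWidthForBitcast(ScalarType t) {
  return t.kind == ScalarKind::Index ? kIndexStorageBitWidth : t.width;
}

// Two types are bitcast-compatible iff their bit widths are equal.
constexpr bool areBitcastCompatible(ScalarType src, ScalarType dst) {
  return bitWidthForBitcast(src) == bitWidthForBitcast(dst);
}

constexpr ScalarType kI32{ScalarKind::Integer, 32};
constexpr ScalarType kI64{ScalarKind::Integer, 64};
constexpr ScalarType kF64{ScalarKind::Float, 64};
constexpr ScalarType kIndex{ScalarKind::Index, 0};

// Compile-time checks of the intended behavior.
static_assert(areBitcastCompatible(kI64, kIndex), "i64 <-> index is allowed");
static_assert(areBitcastCompatible(kIndex, kF64), "index <-> f64 is allowed");
static_assert(!areBitcastCompatible(kI32, kIndex), "i32 <-> index is rejected");
```

Because the width is pinned at 64, a check of this shape accepts i64 <-> index even for targets that lower index to 32 bits. As far as I understand, the index width used during lowering is configured separately (for example through the data layout), not by this verifier check. Separately, the patch binds indexTy from dyn_cast but never uses it. isa<IndexType>(srcType) would express the same test more directly.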

From 0ca3666d385c2e27a35106ed7f6c021cef1c90b7 Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Thu, 2 Jan 2025 15:18:34 +0800
Subject: [PATCH] [mlir][arith] Support bitcast with index type

Use kInternalStorageBitWidth as the bit width of the index type.
---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 13 ++++++++++++-
 mlir/test/Dialect/Arith/ops.mlir       |  6 ++++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d8b314a3fa43c0..6c4aee3aad94fe 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1723,7 +1723,18 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (!srcType || !dstType)
     return false;
 
-  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
+  unsigned srcWidth, dstWidth;
+  if (auto indexTy = dyn_cast<IndexType>(srcType))
+    srcWidth = IndexType::kInternalStorageBitWidth;
+  else
+    srcWidth = srcType.getIntOrFloatBitWidth();
+
+  if (auto indexTy = dyn_cast<IndexType>(dstType))
+    dstWidth = IndexType::kInternalStorageBitWidth;
+  else
+    dstWidth = dstType.getIntOrFloatBitWidth();
+
+  return srcWidth == dstWidth;
 }
 
 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a517..46cb1993a3b789 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -954,6 +954,12 @@ func.func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]x
   return %0 : vector<[8]xi32>
 }
 
+// CHECK-LABEL: test_bitcast_index
+func.func @test_bitcast_index(%arg0 : i64) -> index {
+  %0 = arith.bitcast %arg0 : i64 to index
+  return %0 : index
+}
+
 // CHECK-LABEL: test_cmpi
 func.func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 {
   %0 = arith.cmpi ne, %arg0, %arg1 : i64