[Mlir-commits] [mlir] [mlir][arith] Support bitcast with index type (PR #121455)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 1 23:32:15 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-arith

Author: Jianjian Guan (jacquesguan)

<details>
<summary>Changes</summary>

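Use `IndexType::kInternalStorageBitWidth` as the bit width of the `index` type when checking whether the operand and result types of `arith.bitcast` are compatible.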

---
Full diff: https://github.com/llvm/llvm-project/pull/121455.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+12-1) 
- (modified) mlir/test/Dialect/Arith/ops.mlir (+6) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d8b314a3fa43c0..6c4aee3aad94fe 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1723,7 +1723,18 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (!srcType || !dstType)
     return false;
 
-  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
+  unsigned srcWidth, dstWidth;
+  if (auto indexTy = dyn_cast<IndexType>(srcType))
+    srcWidth = IndexType::kInternalStorageBitWidth;
+  else
+    srcWidth = srcType.getIntOrFloatBitWidth();
+
+  if (auto indexTy = dyn_cast<IndexType>(dstType))
+    dstWidth = IndexType::kInternalStorageBitWidth;
+  else
+    dstWidth = dstType.getIntOrFloatBitWidth();
+
+  return srcWidth == dstWidth;
 }
 
 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a517..46cb1993a3b789 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -954,6 +954,12 @@ func.func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]x
   return %0 : vector<[8]xi32>
 }
 
+// CHECK-LABEL: test_bitcast_index
+func.func @test_bitcast_index(%arg0 : i64) -> index {
+  %0 = arith.bitcast %arg0 : i64 to index
+  return %0 : index
+}
+
 // CHECK-LABEL: test_cmpi
 func.func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 {
   %0 = arith.cmpi ne, %arg0, %arg1 : i64

``````````

</details>


https://github.com/llvm/llvm-project/pull/121455


More information about the Mlir-commits mailing list