[Mlir-commits] [mlir] [mlir][arith][tensor] Disable index type for bitcast (PR #121455)
Jianjian Guan
llvmlistbot at llvm.org
Mon Jan 6 22:18:54 PST 2025
https://github.com/jacquesguan updated https://github.com/llvm/llvm-project/pull/121455
>From 0ca3666d385c2e27a35106ed7f6c021cef1c90b7 Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Thu, 2 Jan 2025 15:18:34 +0800
Subject: [PATCH 1/3] [mlir][arith] Support bitcast with index type
Use kInternalStorageBitWidth as the bit width of index type.
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 13 ++++++++++++-
mlir/test/Dialect/Arith/ops.mlir | 6 ++++++
2 files changed, 18 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d8b314a3fa43c0..6c4aee3aad94fe 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1723,7 +1723,18 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (!srcType || !dstType)
return false;
- return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
+ unsigned srcWidth, dstWidth;
+ if (auto indexTy = dyn_cast<IndexType>(srcType))
+ srcWidth = IndexType::kInternalStorageBitWidth;
+ else
+ srcWidth = srcType.getIntOrFloatBitWidth();
+
+ if (auto indexTy = dyn_cast<IndexType>(dstType))
+ dstWidth = IndexType::kInternalStorageBitWidth;
+ else
+ dstWidth = dstType.getIntOrFloatBitWidth();
+
+ return srcWidth == dstWidth;
}
OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a517..46cb1993a3b789 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -954,6 +954,12 @@ func.func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]x
return %0 : vector<[8]xi32>
}
+// CHECK-LABEL: test_bitcast_index
+func.func @test_bitcast_index(%arg0 : i64) -> index {
+ %0 = arith.bitcast %arg0 : i64 to index
+ return %0 : index
+}
+
// CHECK-LABEL: test_cmpi
func.func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 {
%0 = arith.cmpi ne, %arg0, %arg1 : i64
>From e0f55da77f8d0596150d4979618ccc456540e61d Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Sat, 4 Jan 2025 15:22:01 +0800
Subject: [PATCH 2/3] disable index type for bitcast
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 5 ++---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 19 +++----------------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 ++++
mlir/test/Dialect/Arith/invalid.mlir | 16 ++++++++++++++++
mlir/test/Dialect/Arith/ops.mlir | 6 ------
mlir/test/Dialect/Tensor/invalid.mlir | 16 ++++++++++++++++
6 files changed, 41 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 0722ff68d890de..80b90f2ae480da 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1392,11 +1392,10 @@ def Arith_IndexCastUIOp
// BitcastOp
//===----------------------------------------------------------------------===//
-// Bitcast can convert between memrefs of signless integers, indices, and
-// floats too.
+// Bitcast can convert between memrefs of signless integers and floats.
def BitcastTypeConstraint : TypeConstraint<Or<[
SignlessIntegerOrFloatLike.predicate,
- MemRefOf<[AnySignlessInteger, Index, AnyFloat]>.predicate]>,
+ MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>,
"signless-integer-or-float-like or memref of signless-integer or float">;
def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 6c4aee3aad94fe..21450c3cf2a2dc 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1716,25 +1716,12 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (!areValidCastInputsAndOutputs(inputs, outputs))
return false;
- auto srcType =
- getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
- auto dstType =
- getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
+ auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
+ auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
if (!srcType || !dstType)
return false;
- unsigned srcWidth, dstWidth;
- if (auto indexTy = dyn_cast<IndexType>(srcType))
- srcWidth = IndexType::kInternalStorageBitWidth;
- else
- srcWidth = srcType.getIntOrFloatBitWidth();
-
- if (auto indexTy = dyn_cast<IndexType>(dstType))
- dstWidth = IndexType::kInternalStorageBitWidth;
- else
- dstWidth = dstType.getIntOrFloatBitWidth();
-
- return srcWidth == dstWidth;
+ return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}
OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f79c774ceb3e9a..906afe2fa3358f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -219,6 +219,10 @@ bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (!aT || !bT)
return false;
+ if (isa<IndexType>(aT.getElementType()) ||
+ isa<IndexType>(bT.getElementType()))
+ return false;
+
if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
return false;
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 088da475e8eb4c..54c82f3802ced6 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -853,3 +853,19 @@ func.func @select_tensor_encoding(
%0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo">
return %0 : tensor<8xi32, "foo">
}
+
+// -----
+
+func.func @bitcast_index_0(%arg0 : i64) -> index {
+ // expected-error @+1 {{'arith.bitcast' op operand type 'i64' and result type 'index' are cast incompatible}}
+ %0 = arith.bitcast %arg0 : i64 to index
+ return %0 : index
+}
+
+// -----
+
+func.func @bitcast_index_1(%arg0 : index) -> i64 {
+ // expected-error @+1 {{'arith.bitcast' op operand type 'index' and result type 'i64' are cast incompatible}}
+ %0 = arith.bitcast %arg0 : index to i64
+ return %0 : i64
+}
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 46cb1993a3b789..f684e02344a517 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -954,12 +954,6 @@ func.func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]x
return %0 : vector<[8]xi32>
}
-// CHECK-LABEL: test_bitcast_index
-func.func @test_bitcast_index(%arg0 : i64) -> index {
- %0 = arith.bitcast %arg0 : i64 to index
- return %0 : index
-}
-
// CHECK-LABEL: test_cmpi
func.func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 {
%0 = arith.cmpi ne, %arg0, %arg1 : i64
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 83cb4b9d4ab247..d81a8f345ce1b1 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -807,3 +807,19 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
%0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
+
+// -----
+
+func.func @bitcast_index_0(%arg0 : tensor<?xi64>) -> tensor<?xindex> {
+ // expected-error @+1 {{'tensor.bitcast' op operand type 'tensor<?xi64>' and result type 'tensor<?xindex>' are cast incompatible}}
+ %0 = tensor.bitcast %arg0 : tensor<?xi64> to tensor<?xindex>
+ return %0 : tensor<?xindex>
+}
+
+// -----
+
+func.func @bitcast_index_1(%arg0 : tensor<?xindex>) -> tensor<?xi64> {
+ // expected-error @+1 {{'tensor.bitcast' op operand type 'tensor<?xindex>' and result type 'tensor<?xi64>' are cast incompatible}}
+ %0 = tensor.bitcast %arg0 : tensor<?xindex> to tensor<?xi64>
+ return %0 : tensor<?xi64>
+}
>From 1a5f91cad2e695fb58c2f1e744d24b8a046fd2ff Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Tue, 7 Jan 2025 14:18:08 +0800
Subject: [PATCH 3/3] change type constraint of bitcast
---
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 2 +-
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 6 ++++--
mlir/include/mlir/IR/CommonTypeConstraints.td | 5 +++++
mlir/test/Dialect/Arith/invalid.mlir | 4 ++--
mlir/test/Dialect/Tensor/invalid.mlir | 4 ++--
5 files changed, 14 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 80b90f2ae480da..10d7519e09dbea 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1394,7 +1394,7 @@ def Arith_IndexCastUIOp
// Bitcast can convert between memrefs of signless integers and floats.
def BitcastTypeConstraint : TypeConstraint<Or<[
- SignlessIntegerOrFloatLike.predicate,
+ SignlessInteger.predicate, FloatLike.predicate,
MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>,
"signless-integer-or-float-like or memref of signless-integer or float">;
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 812ac209845020..8ad1b23cb2bfe2 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -75,8 +75,10 @@ def Tensor_BitcastOp : Tensor_Op<"bitcast", [
```
}];
- let arguments = (ins AnyTensor:$source);
- let results = (outs AnyTensor:$dest);
+ let arguments = (ins TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
+ AnySignedInteger, AnyFloat]>:$source);
+ let results = (outs TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
+ AnySignedInteger, AnyFloat]>:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index b9f8c1ed19470d..74eb5ab4c2b5fe 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -908,6 +908,11 @@ def BoolLike : TypeOrContainer<I1, "bool-like">;
def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
+// Type constraint for signless integer types, excluding index: signless integers,
+// vectors of signless integers or tensors of signless integers.
+def SignlessInteger : TypeOrValueSemanticsContainer<
+ AnySignlessInteger, "signless-integer">;
+
// Type constraint for signless-integer-like types: signless integers, indices,
// vectors of signless integers or indices, tensors of signless integers.
def SignlessIntegerLike : TypeOrValueSemanticsContainer<
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 54c82f3802ced6..7bd68372de471e 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -857,7 +857,7 @@ func.func @select_tensor_encoding(
// -----
func.func @bitcast_index_0(%arg0 : i64) -> index {
- // expected-error @+1 {{'arith.bitcast' op operand type 'i64' and result type 'index' are cast incompatible}}
+ // expected-error @+1 {{'arith.bitcast' op result #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
%0 = arith.bitcast %arg0 : i64 to index
return %0 : index
}
@@ -865,7 +865,7 @@ func.func @bitcast_index_0(%arg0 : i64) -> index {
// -----
func.func @bitcast_index_1(%arg0 : index) -> i64 {
- // expected-error @+1 {{'arith.bitcast' op operand type 'index' and result type 'i64' are cast incompatible}}
+ // expected-error @+1 {{'arith.bitcast' op operand #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
%0 = arith.bitcast %arg0 : index to i64
return %0 : i64
}
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index d81a8f345ce1b1..69608617b22260 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -811,7 +811,7 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
// -----
func.func @bitcast_index_0(%arg0 : tensor<?xi64>) -> tensor<?xindex> {
- // expected-error @+1 {{'tensor.bitcast' op operand type 'tensor<?xi64>' and result type 'tensor<?xindex>' are cast incompatible}}
+ // expected-error @+1 {{'tensor.bitcast' op result #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor<?xindex>'}}
%0 = tensor.bitcast %arg0 : tensor<?xi64> to tensor<?xindex>
return %0 : tensor<?xindex>
}
@@ -819,7 +819,7 @@ func.func @bitcast_index_0(%arg0 : tensor<?xi64>) -> tensor<?xindex> {
// -----
func.func @bitcast_index_1(%arg0 : tensor<?xindex>) -> tensor<?xi64> {
- // expected-error @+1 {{'tensor.bitcast' op operand type 'tensor<?xindex>' and result type 'tensor<?xi64>' are cast incompatible}}
+ // expected-error @+1 {{'tensor.bitcast' op operand #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor<?xindex>'}}
%0 = tensor.bitcast %arg0 : tensor<?xindex> to tensor<?xi64>
return %0 : tensor<?xi64>
}
More information about the Mlir-commits
mailing list