[flang-commits] [flang] [mlir] [mlir][arith][tensor] Disable index type for bitcast (PR #121455)

Jianjian Guan via flang-commits flang-commits at lists.llvm.org
Fri Jan 24 00:11:15 PST 2025


https://github.com/jacquesguan updated https://github.com/llvm/llvm-project/pull/121455

>From a122eeec7560cd4a0c0a830cbd28d99a7090756e Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Thu, 2 Jan 2025 15:18:34 +0800
Subject: [PATCH 1/5] [mlir][arith] Support bitcast with index type

Use kInternalStorageBitWidth as the bit width of the index type.
---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 13 ++++++++++++-
 mlir/test/Dialect/Arith/ops.mlir       |  6 ++++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index e016a6e16e59ffd..e16eefd32212fa8 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1747,7 +1747,18 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (!srcType || !dstType)
     return false;
 
-  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
+  unsigned srcWidth, dstWidth;
+  if (auto indexTy = dyn_cast<IndexType>(srcType))
+    srcWidth = IndexType::kInternalStorageBitWidth;
+  else
+    srcWidth = srcType.getIntOrFloatBitWidth();
+
+  if (auto indexTy = dyn_cast<IndexType>(dstType))
+    dstWidth = IndexType::kInternalStorageBitWidth;
+  else
+    dstWidth = dstType.getIntOrFloatBitWidth();
+
+  return srcWidth == dstWidth;
 }
 
 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a517e..46cb1993a3b7896 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -954,6 +954,12 @@ func.func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]x
   return %0 : vector<[8]xi32>
 }
 
+// CHECK-LABEL: test_bitcast_index
+func.func @test_bitcast_index(%arg0 : i64) -> index {
+  %0 = arith.bitcast %arg0 : i64 to index
+  return %0 : index
+}
+
 // CHECK-LABEL: test_cmpi
 func.func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 {
   %0 = arith.cmpi ne, %arg0, %arg1 : i64

>From 4243d0377d20e94183204de23ba50fe0bf0b591f Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Sat, 4 Jan 2025 15:22:01 +0800
Subject: [PATCH 2/5] disable index type for bitcast

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td |  5 ++---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 19 +++----------------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  4 ++++
 mlir/test/Dialect/Arith/invalid.mlir          | 16 ++++++++++++++++
 mlir/test/Dialect/Arith/ops.mlir              |  6 ------
 mlir/test/Dialect/Tensor/invalid.mlir         | 16 ++++++++++++++++
 6 files changed, 41 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 0722ff68d890ded..80b90f2ae480da1 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1392,11 +1392,10 @@ def Arith_IndexCastUIOp
 // BitcastOp
 //===----------------------------------------------------------------------===//
 
-// Bitcast can convert between memrefs of signless integers, indices, and
-// floats too.
+// Bitcast can convert between memrefs of signless integers and floats.
 def BitcastTypeConstraint : TypeConstraint<Or<[
         SignlessIntegerOrFloatLike.predicate,
-        MemRefOf<[AnySignlessInteger, Index, AnyFloat]>.predicate]>,
+        MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>,
     "signless-integer-or-float-like or memref of signless-integer or float">;
 
 def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index e16eefd32212fa8..7ca104691e6df62 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1740,25 +1740,12 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (!areValidCastInputsAndOutputs(inputs, outputs))
     return false;
 
-  auto srcType =
-      getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
-  auto dstType =
-      getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
+  auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
+  auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
   if (!srcType || !dstType)
     return false;
 
-  unsigned srcWidth, dstWidth;
-  if (auto indexTy = dyn_cast<IndexType>(srcType))
-    srcWidth = IndexType::kInternalStorageBitWidth;
-  else
-    srcWidth = srcType.getIntOrFloatBitWidth();
-
-  if (auto indexTy = dyn_cast<IndexType>(dstType))
-    dstWidth = IndexType::kInternalStorageBitWidth;
-  else
-    dstWidth = dstType.getIntOrFloatBitWidth();
-
-  return srcWidth == dstWidth;
+  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
 }
 
 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 24a1d5531531981..9cebb5534ebddb9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -219,6 +219,10 @@ bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (!aT || !bT)
     return false;
 
+  if (isa<IndexType>(aT.getElementType()) ||
+      isa<IndexType>(bT.getElementType()))
+    return false;
+
   if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
     return false;
 
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 088da475e8eb4ce..54c82f3802ced6e 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -853,3 +853,19 @@ func.func @select_tensor_encoding(
   %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo">
   return %0 : tensor<8xi32, "foo">
 }
+
+// -----
+
+func.func @bitcast_index_0(%arg0 : i64) -> index {
+  // expected-error @+1 {{'arith.bitcast' op operand type 'i64' and result type 'index' are cast incompatible}}
+  %0 = arith.bitcast %arg0 : i64 to index
+  return %0 : index
+}
+
+// -----
+
+func.func @bitcast_index_1(%arg0 : index) -> i64 {
+  // expected-error @+1 {{'arith.bitcast' op operand type 'index' and result type 'i64' are cast incompatible}}
+  %0 = arith.bitcast %arg0 : index to i64
+  return %0 : i64
+}
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 46cb1993a3b7896..f684e02344a517e 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -954,12 +954,6 @@ func.func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]x
   return %0 : vector<[8]xi32>
 }
 
-// CHECK-LABEL: test_bitcast_index
-func.func @test_bitcast_index(%arg0 : i64) -> index {
-  %0 = arith.bitcast %arg0 : i64 to index
-  return %0 : index
-}
-
 // CHECK-LABEL: test_cmpi
 func.func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 {
   %0 = arith.cmpi ne, %arg0, %arg1 : i64
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 1de3e281bc462b3..23c1f5360d361db 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -807,3 +807,19 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
   %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+func.func @bitcast_index_0(%arg0 : tensor<?xi64>) -> tensor<?xindex> {
+  // expected-error @+1 {{'tensor.bitcast' op operand type 'tensor<?xi64>' and result type 'tensor<?xindex>' are cast incompatible}}
+  %0 = tensor.bitcast %arg0 : tensor<?xi64> to tensor<?xindex>
+  return %0 : tensor<?xindex>
+}
+
+// -----
+
+func.func @bitcast_index_1(%arg0 : tensor<?xindex>) -> tensor<?xi64> {
+  // expected-error @+1 {{'tensor.bitcast' op operand type 'tensor<?xindex>' and result type 'tensor<?xi64>' are cast incompatible}}
+  %0 = tensor.bitcast %arg0 : tensor<?xindex> to tensor<?xi64>
+  return %0 : tensor<?xi64>
+}

>From 31a53225b96dac17363d6156a1f33f115fa1672e Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Tue, 7 Jan 2025 14:18:08 +0800
Subject: [PATCH 3/5] change type constraint of bitcast

---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td   | 2 +-
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 6 ++++--
 mlir/include/mlir/IR/CommonTypeConstraints.td    | 5 +++++
 mlir/test/Dialect/Arith/invalid.mlir             | 4 ++--
 mlir/test/Dialect/Tensor/invalid.mlir            | 4 ++--
 5 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 80b90f2ae480da1..10d7519e09dbeab 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1394,7 +1394,7 @@ def Arith_IndexCastUIOp
 
 // Bitcast can convert between memrefs of signless integers and floats.
 def BitcastTypeConstraint : TypeConstraint<Or<[
-        SignlessIntegerOrFloatLike.predicate,
+        SignlessInteger.predicate, FloatLike.predicate,
         MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>,
     "signless-integer-or-float-like or memref of signless-integer or float">;
 
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 812ac2098450204..8ad1b23cb2bfe27 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -75,8 +75,10 @@ def Tensor_BitcastOp : Tensor_Op<"bitcast", [
     ```
   }];
 
-  let arguments = (ins AnyTensor:$source);
-  let results = (outs AnyTensor:$dest);
+  let arguments = (ins TensorOf<[AnySignlessInteger, AnyUnsignedInteger, 
+                                 AnySignedInteger, AnyFloat]>:$source);
+  let results = (outs TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
+                                AnySignedInteger, AnyFloat]>:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
 
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index e59291030356861..38bff642630fee8 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -908,6 +908,11 @@ def BoolLike : TypeOrContainer<I1, "bool-like">;
 
 def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
 
+// Type constraint for signless-integer-like types: signless integers, 
+// vectors of signless integers or tensors of signless integers.
+def SignlessInteger : TypeOrValueSemanticsContainer<
+    AnySignlessInteger, "signless-integer">;
+
 // Type constraint for signless-integer-like types: signless integers, indices,
 // vectors of signless integers or indices, tensors of signless integers.
 def SignlessIntegerLike : TypeOrValueSemanticsContainer<
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 54c82f3802ced6e..7bd68372de471e1 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -857,7 +857,7 @@ func.func @select_tensor_encoding(
 // -----
 
 func.func @bitcast_index_0(%arg0 : i64) -> index {
-  // expected-error @+1 {{'arith.bitcast' op operand type 'i64' and result type 'index' are cast incompatible}}
+  // expected-error @+1 {{'arith.bitcast' op result #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
   %0 = arith.bitcast %arg0 : i64 to index
   return %0 : index
 }
@@ -865,7 +865,7 @@ func.func @bitcast_index_0(%arg0 : i64) -> index {
 // -----
 
 func.func @bitcast_index_1(%arg0 : index) -> i64 {
-  // expected-error @+1 {{'arith.bitcast' op operand type 'index' and result type 'i64' are cast incompatible}}
+  // expected-error @+1 {{'arith.bitcast' op operand #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
   %0 = arith.bitcast %arg0 : index to i64
   return %0 : i64
 }
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 23c1f5360d361db..0c6d8f4e05c3324 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -811,7 +811,7 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
 // -----
 
 func.func @bitcast_index_0(%arg0 : tensor<?xi64>) -> tensor<?xindex> {
-  // expected-error @+1 {{'tensor.bitcast' op operand type 'tensor<?xi64>' and result type 'tensor<?xindex>' are cast incompatible}}
+  // expected-error @+1 {{'tensor.bitcast' op result #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor<?xindex>'}}
   %0 = tensor.bitcast %arg0 : tensor<?xi64> to tensor<?xindex>
   return %0 : tensor<?xindex>
 }
@@ -819,7 +819,7 @@ func.func @bitcast_index_0(%arg0 : tensor<?xi64>) -> tensor<?xindex> {
 // -----
 
 func.func @bitcast_index_1(%arg0 : tensor<?xindex>) -> tensor<?xi64> {
-  // expected-error @+1 {{'tensor.bitcast' op operand type 'tensor<?xindex>' and result type 'tensor<?xi64>' are cast incompatible}}
+  // expected-error @+1 {{'tensor.bitcast' op operand #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor<?xindex>'}}
   %0 = tensor.bitcast %arg0 : tensor<?xindex> to tensor<?xi64>
   return %0 : tensor<?xi64>
 }

>From 20981e4668cd09f12476706e7b644d29245aecea Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Tue, 7 Jan 2025 17:49:45 +0800
Subject: [PATCH 4/5] Address rename comment

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 30 +++++++++----------
 mlir/include/mlir/Dialect/Math/IR/MathOps.td  | 10 +++----
 mlir/include/mlir/IR/CommonTypeConstraints.td | 13 +++++---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  4 ---
 4 files changed, 29 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 10d7519e09dbeab..ea9b0f6509b80b6 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -51,8 +51,8 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
 class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic, traits #
       [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
-    Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>,
-    Results<(outs SignlessIntegerLike:$result)>;
+    Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>,
+    Results<(outs SignlessIntegerOrIndexLike:$result)>;
 
 // Base class for integer binary operations without undefined behavior.
 class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
@@ -155,11 +155,11 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
     Arith_BinaryOp<mnemonic, traits #
       [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
        DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
-    Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
+    Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
       DefaultValuedAttr<
         Arith_IntegerOverflowAttr,
         "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
-    Results<(outs SignlessIntegerLike:$result)> {
+    Results<(outs SignlessIntegerOrIndexLike:$result)> {
 
   let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
                           attr-dict `:` type($result) }];
@@ -198,7 +198,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
   // However, it is necessary to allow arith.constant to return vectors/tensors
   // of strings and signed/unsigned integers (for now) as an artefact of
   // splitting the Standard dialect.
-  let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
+  let results = (outs /*SignlessIntegerOrIndexOrFloatLike*/AnyType:$result);
 
   let extraClassDeclaration = [{
     /// Whether the constant op can be constructed with a particular value and
@@ -288,8 +288,8 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
-  let results = (outs SignlessIntegerLike:$sum, BoolLike:$overflow);
+  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
   let assemblyFormat = [{
     $lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
   }];
@@ -429,8 +429,8 @@ def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
-  let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
+  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
 
@@ -472,8 +472,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
-  let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
+  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
 
@@ -1350,7 +1350,7 @@ def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
 
 // Index cast can convert between memrefs of signless integers and indices too.
 def IndexCastTypeConstraint : TypeConstraint<Or<[
-        SignlessIntegerLike.predicate,
+        SignlessIntegerOrIndexLike.predicate,
         MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
     "signless-integer-like or memref of signless-integer">;
 
@@ -1394,7 +1394,7 @@ def Arith_IndexCastUIOp
 
 // Bitcast can convert between memrefs of signless integers and floats.
 def BitcastTypeConstraint : TypeConstraint<Or<[
-        SignlessInteger.predicate, FloatLike.predicate,
+        SignlessIntegerOrFloatLike.predicate,
         MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>,
     "signless-integer-or-float-like or memref of signless-integer or float">;
 
@@ -1495,8 +1495,8 @@ def Arith_CmpIOp
   }];
 
   let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
-                       SignlessIntegerLikeOfAnyRank:$lhs,
-                       SignlessIntegerLikeOfAnyRank:$rhs);
+                       SignlessIntegerOrIndexLikeOfAnyRank:$lhs,
+                       SignlessIntegerOrIndexLikeOfAnyRank:$rhs);
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 3f6d2d2e44783f8..5990a9f0d2e442b 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -28,8 +28,8 @@ class Math_Op<string mnemonic, list<Trait> traits = []> :
 // tensor thereof.
 class Math_IntegerUnaryOp<string mnemonic, list<Trait> traits = []> :
     Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
-  let arguments = (ins SignlessIntegerLike:$operand);
-  let results = (outs SignlessIntegerLike:$result);
+  let arguments = (ins SignlessIntegerOrIndexLike:$operand);
+  let results = (outs SignlessIntegerOrIndexLike:$result);
 
   let assemblyFormat = "$operand attr-dict `:` type($result)";
 }
@@ -55,8 +55,8 @@ class Math_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
 // type, vector or tensor thereof.
 class Math_IntegerBinaryOp<string mnemonic, list<Trait> traits = []> :
     Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
-  let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
-  let results = (outs SignlessIntegerLike:$result);
+  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs SignlessIntegerOrIndexLike:$result);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
 }
@@ -976,7 +976,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
     ```
   }];
 
-  let arguments = (ins FloatLike:$lhs, SignlessIntegerLike:$rhs,
+  let arguments = (ins FloatLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
       DefaultValuedAttr<Arith_FastMathAttr,
                         "::mlir::arith::FastMathFlags::none">:$fastmath);
   let results = (outs FloatLike:$result);
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 38bff642630fee8..82e335e30b6fa47 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -910,24 +910,29 @@ def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
 
 // Type constraint for signless-integer-like types: signless integers, 
 // vectors of signless integers or tensors of signless integers.
-def SignlessInteger : TypeOrValueSemanticsContainer<
+def SignlessIntegerLike : TypeOrValueSemanticsContainer<
     AnySignlessInteger, "signless-integer">;
 
 // Type constraint for signless-integer-like types: signless integers, indices,
 // vectors of signless integers or indices, tensors of signless integers.
-def SignlessIntegerLike : TypeOrValueSemanticsContainer<
+def SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
     AnySignlessIntegerOrIndex, "signless-integer-like">;
 
-def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<
+def SignlessIntegerOrIndexLikeOfAnyRank : TypeOrContainerOfAnyRank<
     AnySignlessIntegerOrIndex,
     "signless-integer-like">;
 
 // Type constraint for float-like types: floats, vectors or tensors thereof.
 def FloatLike : TypeOrContainer<AnyFloat, "floating-point-like">;
 
-// Type constraint for signless-integer-like or float-like types.
+// Type constraint for signless-integer-like or float-like types.
 def SignlessIntegerOrFloatLike : TypeConstraint<Or<[
     SignlessIntegerLike.predicate, FloatLike.predicate]>,
     "signless-integer-like or floating-point-like">;
 
+// Type constraint for signless-integer-or-index-like or float-like types.
+def SignlessIntegerOrIndexOrFloatLike : TypeConstraint<Or<[
+    SignlessIntegerOrIndexLike.predicate, FloatLike.predicate]>,
+    "signless-integer-or-index-like or floating-point-like">;
+
 #endif // COMMON_TYPE_CONSTRAINTS_TD
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 9cebb5534ebddb9..24a1d5531531981 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -219,10 +219,6 @@ bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (!aT || !bT)
     return false;
 
-  if (isa<IndexType>(aT.getElementType()) ||
-      isa<IndexType>(bT.getElementType()))
-    return false;
-
   if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
     return false;
 

>From 5767f51d164056c2cd675e3d865d21b1e3e3029f Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Fri, 24 Jan 2025 16:10:42 +0800
Subject: [PATCH 5/5] Fix flang error

---
 flang/include/flang/Optimizer/Dialect/FIRTypes.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 6ae74f16a72d37a..41e765c1cb7b9fd 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -579,7 +579,7 @@ def IsBaseBoxTypePred
 def fir_BaseBoxType : Type<IsBaseBoxTypePred, "fir.box or fir.class type">;
 
 // Generalized FIR and standard dialect types representing intrinsic types
-def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerLike.predicate,
+def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerOrIndexLike.predicate,
     AnySignedInteger.predicate, AnyUnsignedInteger.predicate,
     fir_IntegerType.predicate, fir_UnsignedType.predicate]>, "any integer">;
 def AnyLogicalLike : TypeConstraint<Or<[BoolLike.predicate,



More information about the flang-commits mailing list