[flang-commits] [flang] 990837f - [mlir][arith][tensor] Disable index type for bitcast (#121455)
via flang-commits
flang-commits at lists.llvm.org
Fri Jan 24 00:53:08 PST 2025
Author: Jianjian Guan
Date: 2025-01-24T16:53:04+08:00
New Revision: 990837f91de329b1e045f90fadb86ffe21611d9a
URL: https://github.com/llvm/llvm-project/commit/990837f91de329b1e045f90fadb86ffe21611d9a
DIFF: https://github.com/llvm/llvm-project/commit/990837f91de329b1e045f90fadb86ffe21611d9a.diff
LOG: [mlir][arith][tensor] Disable index type for bitcast (#121455)
Fixes #121397.
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIRTypes.td
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/IR/CommonTypeConstraints.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/invalid.mlir
mlir/test/Dialect/Tensor/invalid.mlir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 6ae74f16a72d37..41e765c1cb7b9f 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -579,7 +579,7 @@ def IsBaseBoxTypePred
def fir_BaseBoxType : Type<IsBaseBoxTypePred, "fir.box or fir.class type">;
// Generalized FIR and standard dialect types representing intrinsic types
-def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerLike.predicate,
+def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerOrIndexLike.predicate,
AnySignedInteger.predicate, AnyUnsignedInteger.predicate,
fir_IntegerType.predicate, fir_UnsignedType.predicate]>, "any integer">;
def AnyLogicalLike : TypeConstraint<Or<[BoolLike.predicate,
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 0722ff68d890de..ea9b0f6509b80b 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -51,8 +51,8 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
- Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>,
- Results<(outs SignlessIntegerLike:$result)>;
+ Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>,
+ Results<(outs SignlessIntegerOrIndexLike:$result)>;
// Base class for integer binary operations without undefined behavior.
class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
@@ -155,11 +155,11 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
Arith_BinaryOp<mnemonic, traits #
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
- Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
+ Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
DefaultValuedAttr<
Arith_IntegerOverflowAttr,
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
- Results<(outs SignlessIntegerLike:$result)> {
+ Results<(outs SignlessIntegerOrIndexLike:$result)> {
let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
attr-dict `:` type($result) }];
@@ -198,7 +198,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
// However, it is necessary to allow arith.constant to return vectors/tensors
// of strings and signed/unsigned integers (for now) as an artefact of
// splitting the Standard dialect.
- let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
+ let results = (outs /*SignlessIntegerOrIndexOrFloatLike*/AnyType:$result);
let extraClassDeclaration = [{
/// Whether the constant op can be constructed with a particular value and
@@ -288,8 +288,8 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
- let results = (outs SignlessIntegerLike:$sum, BoolLike:$overflow);
+ let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
}];
@@ -429,8 +429,8 @@ def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
- let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
+ let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
@@ -472,8 +472,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
- let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
+ let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
@@ -1350,7 +1350,7 @@ def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
// Index cast can convert between memrefs of signless integers and indices too.
def IndexCastTypeConstraint : TypeConstraint<Or<[
- SignlessIntegerLike.predicate,
+ SignlessIntegerOrIndexLike.predicate,
MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
"signless-integer-like or memref of signless-integer">;
@@ -1392,11 +1392,10 @@ def Arith_IndexCastUIOp
// BitcastOp
//===----------------------------------------------------------------------===//
-// Bitcast can convert between memrefs of signless integers, indices, and
-// floats too.
+// Bitcast can also convert between memrefs of signless integers and floats.
def BitcastTypeConstraint : TypeConstraint<Or<[
SignlessIntegerOrFloatLike.predicate,
- MemRefOf<[AnySignlessInteger, Index, AnyFloat]>.predicate]>,
+ MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>,
"signless-integer-or-float-like or memref of signless-integer or float">;
def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
@@ -1496,8 +1495,8 @@ def Arith_CmpIOp
}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
- SignlessIntegerLikeOfAnyRank:$lhs,
- SignlessIntegerLikeOfAnyRank:$rhs);
+ SignlessIntegerOrIndexLikeOfAnyRank:$lhs,
+ SignlessIntegerOrIndexLikeOfAnyRank:$rhs);
let hasFolder = 1;
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 3f6d2d2e44783f..5990a9f0d2e442 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -28,8 +28,8 @@ class Math_Op<string mnemonic, list<Trait> traits = []> :
// tensor thereof.
class Math_IntegerUnaryOp<string mnemonic, list<Trait> traits = []> :
Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
- let arguments = (ins SignlessIntegerLike:$operand);
- let results = (outs SignlessIntegerLike:$result);
+ let arguments = (ins SignlessIntegerOrIndexLike:$operand);
+ let results = (outs SignlessIntegerOrIndexLike:$result);
let assemblyFormat = "$operand attr-dict `:` type($result)";
}
@@ -55,8 +55,8 @@ class Math_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
// type, vector or tensor thereof.
class Math_IntegerBinaryOp<string mnemonic, list<Trait> traits = []> :
Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
- let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
- let results = (outs SignlessIntegerLike:$result);
+ let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs SignlessIntegerOrIndexLike:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
}
@@ -976,7 +976,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
```
}];
- let arguments = (ins FloatLike:$lhs, SignlessIntegerLike:$rhs,
+ let arguments = (ins FloatLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
DefaultValuedAttr<Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath);
let results = (outs FloatLike:$result);
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 812ac209845020..8ad1b23cb2bfe2 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -75,8 +75,10 @@ def Tensor_BitcastOp : Tensor_Op<"bitcast", [
```
}];
- let arguments = (ins AnyTensor:$source);
- let results = (outs AnyTensor:$dest);
+ let arguments = (ins TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
+ AnySignedInteger, AnyFloat]>:$source);
+ let results = (outs TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
+ AnySignedInteger, AnyFloat]>:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index e5929103035686..82e335e30b6fa4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -908,21 +908,31 @@ def BoolLike : TypeOrContainer<I1, "bool-like">;
def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
+// Type constraint for signless-integer-like types: signless integers,
+// vectors of signless integers or tensors of signless integers.
+def SignlessIntegerLike : TypeOrValueSemanticsContainer<
+ AnySignlessInteger, "signless-integer">;
+
// Type constraint for signless-integer-like types: signless integers, indices,
// vectors of signless integers or indices, tensors of signless integers.
-def SignlessIntegerLike : TypeOrValueSemanticsContainer<
+def SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
AnySignlessIntegerOrIndex, "signless-integer-like">;
-def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<
+def SignlessIntegerOrIndexLikeOfAnyRank : TypeOrContainerOfAnyRank<
AnySignlessIntegerOrIndex,
"signless-integer-like">;
// Type constraint for float-like types: floats, vectors or tensors thereof.
def FloatLike : TypeOrContainer<AnyFloat, "floating-point-like">;
-// Type constraint for signless-integer-like or float-like types.
+// Type constraint for signless-integer-like (excluding index) or float-like types.
def SignlessIntegerOrFloatLike : TypeConstraint<Or<[
SignlessIntegerLike.predicate, FloatLike.predicate]>,
"signless-integer-like or floating-point-like">;
+// Type constraint for signless-integer-or-index-like or float-like types.
+def SignlessIntegerOrIndexOrFloatLike : TypeConstraint<Or<[
+ SignlessIntegerOrIndexLike.predicate, FloatLike.predicate]>,
+ "signless-integer-or-index-like or floating-point-like">;
+
#endif // COMMON_TYPE_CONSTRAINTS_TD
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index e016a6e16e59ff..7ca104691e6df6 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1740,10 +1740,8 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (!areValidCastInputsAndOutputs(inputs, outputs))
return false;
- auto srcType =
- getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
- auto dstType =
- getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
+ auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
+ auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
if (!srcType || !dstType)
return false;
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 088da475e8eb4c..7bd68372de471e 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -853,3 +853,19 @@ func.func @select_tensor_encoding(
%0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo">
return %0 : tensor<8xi32, "foo">
}
+
+// -----
+
+func.func @bitcast_index_0(%arg0 : i64) -> index {
+ // expected-error @+1 {{'arith.bitcast' op result #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
+ %0 = arith.bitcast %arg0 : i64 to index
+ return %0 : index
+}
+
+// -----
+
+func.func @bitcast_index_1(%arg0 : index) -> i64 {
+ // expected-error @+1 {{'arith.bitcast' op operand #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
+ %0 = arith.bitcast %arg0 : index to i64
+ return %0 : i64
+}
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 1de3e281bc462b..0c6d8f4e05c332 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -807,3 +807,19 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
%0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
+
+// -----
+
+func.func @bitcast_index_0(%arg0 : tensor<?xi64>) -> tensor<?xindex> {
+ // expected-error @+1 {{'tensor.bitcast' op result #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor<?xindex>'}}
+ %0 = tensor.bitcast %arg0 : tensor<?xi64> to tensor<?xindex>
+ return %0 : tensor<?xindex>
+}
+
+// -----
+
+func.func @bitcast_index_1(%arg0 : tensor<?xindex>) -> tensor<?xi64> {
+ // expected-error @+1 {{'tensor.bitcast' op operand #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor<?xindex>'}}
+ %0 = tensor.bitcast %arg0 : tensor<?xindex> to tensor<?xi64>
+ return %0 : tensor<?xi64>
+}
More information about the flang-commits
mailing list