[Mlir-commits] [mlir] [MLIR][Arith][Vector] Reject i0 integer type in arith and vector ops (PR #183589)
Mehdi Amini
llvmlistbot at llvm.org
Wed Mar 4 03:21:41 PST 2026
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/183589
>From 85502384562a7ab17c4ef3603100bbdbb1658d94 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Feb 2026 10:21:20 -0800
Subject: [PATCH] [MLIR][Arith][Vector] Reject i0 integer type in arith and
vector ops
Add ODS type constraints that exclude zero-bitwidth integers (i0) from
operations in the arith and vector dialects. i0 has no meaningful
arithmetic representation and operations on it can trigger undefined
behavior (e.g. bitwidth calculations assuming non-zero width).
Changes:
- Add `AnyNonZeroBitwidthSignlessInteger` (as a `ConfinedType` over
`AnySignlessInteger`) and `AnyNonZeroBitwidthSignlessIntegerOrIndex`
to CommonTypeConstraints.td.
- Introduce `Arith_SignlessIntegerOrIndexLike` in ArithOps.td that wraps
`AnyNonZeroBitwidthSignlessIntegerOrIndex` via
`TypeOrValueSemanticsContainer`, and update `SignlessFixedWidthIntegerLike`
to use `AnyNonZeroBitwidthSignlessInteger`. Replace all uses of the
shared `SignlessIntegerOrIndexLike` in ArithOps.td with the new
dialect-local constraint.
- Update `IndexCastTypeConstraint` to use `Arith_SignlessIntegerOrIndexLike`.
- Update `BitcastTypeConstraint` to exclude i0 by composing the already-
defined `SignlessFixedWidthIntegerLike` and `FloatLike` constraints,
keeping the definition compact (3 alternatives instead of 7).
- Add `AnyVectorOfNonI0Elem` and `AnyVectorOfNonZeroRankNonI0Elem` in
VectorOps.td and apply them to `vector.contract`, `vector.reduction`,
`vector.multi_reduction`, `vector.outerproduct`, `vector.bitcast`, and
`vector.scan`.
- Update arith/invalid.mlir with explicit i0 rejection tests covering all
integer op families (binary ops, cast ops, extended-multiply ops, cmpi,
bitcast, index_cast, index_castui) in scalar form, plus a vector<N> form
for addi.
- Update vector/invalid.mlir with i0 rejection tests for all covered ops.
- Remove the now-invalid i0 canonicalization tests from
arith/canonicalize.mlir.
---
.../mlir/Dialect/Arith/IR/ArithBase.td | 4 +
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 58 ++++---
.../mlir/Dialect/Vector/IR/VectorOps.td | 35 ++--
mlir/include/mlir/IR/CommonTypeConstraints.td | 11 ++
mlir/test/Dialect/Arith/canonicalize.mlir | 81 ---------
mlir/test/Dialect/Arith/invalid.mlir | 155 +++++++++++++++++-
mlir/test/Dialect/Vector/invalid.mlir | 56 +++++++
7 files changed, 275 insertions(+), 125 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index a3e9fc3c57bb1..d4649c45c90c2 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -25,6 +25,10 @@ def Arith_Dialect : Dialect {
propagate poison values, i.e., if any of its inputs are poison, then the
output is poison. Unless otherwise stated, operations applied to `vector`
and `tensor` values propagates poison elementwise.
+
+ Manipulating values of type `i0` isn't supported in this dialect at the
+ moment and is considered invalid. This can change in the future if
+ motivating use-cases are presented.
}];
let hasConstantMaterializer = 1;
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 4ee9187c7f224..ab85574069687 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -21,6 +21,12 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/EnumAttr.td"
+// Type constraint for signless-integer-or-index-like types that additionally
+// excludes i0 (zero-bitwidth integers), used for arith integer operations.
+def Arith_SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
+ AnyNonZeroBitwidthSignlessIntegerOrIndex,
+ "signless-non-zero-bitwidth-integer-like">;
+
// Base class for Arith dialect ops. Ops in this dialect have no memory
// effects and can be applied element-wise to vectors and tensors.
class Arith_Op<string mnemonic, list<Trait> traits = []> :
@@ -51,8 +57,8 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
- Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>,
- Results<(outs SignlessIntegerOrIndexLike:$result)>;
+ Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs)>,
+ Results<(outs Arith_SignlessIntegerOrIndexLike:$result)>;
// Base class for integer binary operations without undefined behavior.
class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
@@ -110,10 +116,12 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
// Casts do not accept indices. Type constraint for signless-integer-like types
// excluding indices: signless integers, vectors or tensors thereof.
+// i0 (zero-bitwidth) integers are excluded as they have no meaningful
+// representation in cast operations.
def SignlessFixedWidthIntegerLike : TypeConstraint<Or<[
- AnySignlessInteger.predicate,
- VectorOfAnyRankOf<[AnySignlessInteger]>.predicate,
- TensorOf<[AnySignlessInteger]>.predicate]>,
+ AnyNonZeroBitwidthSignlessInteger.predicate,
+ VectorOfAnyRankOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate,
+ TensorOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate]>,
"signless-fixed-width-integer-like">;
// Cast from an integer type to another integer type.
@@ -148,11 +156,11 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
Arith_BinaryOp<mnemonic, traits #
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
- Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
+ Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs,
DefaultValuedAttr<
Arith_IntegerOverflowAttr,
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
- Results<(outs SignlessIntegerOrIndexLike:$result)> {
+ Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
attr-dict `:` type($result) }];
@@ -162,10 +170,10 @@ class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<ArithExactFlagInterface>]>,
- Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
- SignlessIntegerOrIndexLike:$rhs,
+ Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs,
+ Arith_SignlessIntegerOrIndexLike:$rhs,
UnitAttr:$isExact)>,
- Results<(outs SignlessIntegerOrIndexLike:$result)> {
+ Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)?
attr-dict `:` type($result) }];
@@ -294,8 +302,8 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
- let results = (outs SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
+ let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs Arith_SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
}];
@@ -435,8 +443,8 @@ def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
- let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+ let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
@@ -478,8 +486,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
```
}];
- let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
- let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+ let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+ let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
@@ -1574,9 +1582,9 @@ def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
// Index cast can convert between memrefs of signless integers and indices too.
def IndexCastTypeConstraint : TypeConstraint<Or<[
- SignlessIntegerOrIndexLike.predicate,
- MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
- "signless-integer-like or memref of signless-integer">;
+ Arith_SignlessIntegerOrIndexLike.predicate,
+ MemRefOf<[AnyNonZeroBitwidthSignlessInteger, Index]>.predicate]>,
+ "signless-non-zero-bitwidth-integer-like or memref of signless-integer">;
def Arith_IndexCastOp
: Arith_CastOp<"index_cast", IndexCastTypeConstraint, IndexCastTypeConstraint,
@@ -1664,10 +1672,12 @@ def Arith_IndexCastUIOp
//===----------------------------------------------------------------------===//
// Bitcast can convert between memrefs of signless integers and floats.
+// i0 (zero-bitwidth) integers are excluded: a 0-bit bitcast is meaningless.
def BitcastTypeConstraint : TypeConstraint<Or<[
- SignlessIntegerOrFloatLike.predicate,
- MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>,
- "signless-integer-or-float-like or memref of signless-integer or float">;
+ SignlessFixedWidthIntegerLike.predicate,
+ FloatLike.predicate,
+ MemRefOf<[AnyNonZeroBitwidthSignlessInteger, AnyFloat]>.predicate]>,
+ "non-zero-bitwidth-signless-integer-or-float-like or memref of non-zero-bitwidth signless integer or float">;
def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
BitcastTypeConstraint> {
@@ -1766,8 +1776,8 @@ def Arith_CmpIOp : Arith_CompareOp<"cmpi",
}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
- SignlessIntegerOrIndexLike:$lhs,
- SignlessIntegerOrIndexLike:$rhs);
+ Arith_SignlessIntegerOrIndexLike:$lhs,
+ Arith_SignlessIntegerOrIndexLike:$rhs);
let hasFolder = 1;
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ddb04b6bbe40d..43ad435ccf1c1 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -32,6 +32,17 @@ include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/EnumAttr.td"
+// Type constraint helpers for the vector dialect.
+
+// Any vector of any rank whose element type is not i0 (zero-bitwidth integer).
+// Floats, index, and integers with width >= 1 are all accepted.
+def AnyVectorOfNonI0Elem : VectorOfAnyRankOf<[
+ Type<CPred<"!$_self.isInteger(0)">, "non-zero-bitwidth type">]>;
+
+// Like AnyVectorOfNonI0Elem but additionally requires rank >= 1.
+def AnyVectorOfNonZeroRankNonI0Elem : VectorOfNonZeroRankOf<[
+ Type<CPred<"!$_self.isInteger(0)">, "non-zero-bitwidth type">]>;
+
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
def Vector_ContractionOp :
@@ -44,7 +55,7 @@ def Vector_ContractionOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
- Arguments<(ins AnyVectorOfNonZeroRank:$lhs, AnyVectorOfNonZeroRank:$rhs, AnyType:$acc,
+ Arguments<(ins AnyVectorOfNonZeroRankNonI0Elem:$lhs, AnyVectorOfNonZeroRankNonI0Elem:$rhs, AnyType:$acc,
ArrayAttr:$indexing_maps,
Vector_IteratorTypeArrayAttr:$iterator_types,
DefaultValuedAttr<Vector_CombiningKindAttr,
@@ -239,7 +250,7 @@ def Vector_ReductionOp :
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
- AnyVectorOfAnyRank:$vector,
+ AnyVectorOfNonI0Elem:$vector,
Optional<AnyType>:$acc,
DefaultValuedAttr<
Arith_FastMathAttr,
@@ -299,7 +310,7 @@ def Vector_MultiDimReductionOp :
DeclareOpInterfaceMethods<VectorUnrollOpInterface,
["getShapeForUnroll"]>]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
- AnyVectorOfNonZeroRank:$source,
+ AnyVectorOfNonZeroRankNonI0Elem:$source,
AnyType:$acc,
DenseI64ArrayAttr:$reduction_dims)>,
Results<(outs AnyType:$dest)> {
@@ -1094,10 +1105,10 @@ def Vector_OuterProductOp :
PredOpTrait<"rhs operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>,
DeclareOpInterfaceMethods<MaskableOpInterface>]>,
- Arguments<(ins AnyVectorOfNonZeroRank:$lhs, AnyType:$rhs,
- Optional<AnyVectorOfNonZeroRank>:$acc,
+ Arguments<(ins AnyVectorOfNonZeroRankNonI0Elem:$lhs, AnyType:$rhs,
+ Optional<AnyVectorOfNonZeroRankNonI0Elem>:$acc,
DefaultValuedAttr<Vector_CombiningKindAttr, "CombiningKind::ADD">:$kind)>,
- Results<(outs AnyVectorOfNonZeroRank)> {
+ Results<(outs AnyVectorOfNonZeroRankNonI0Elem)> {
let summary = "vector outerproduct with optional fused add";
let description = [{
Takes 2 1-D vectors and returns the 2-D vector containing the outer-product,
@@ -2454,8 +2465,8 @@ def Vector_ShapeCastOp :
def Vector_BitCastOp :
Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
- Arguments<(ins AnyVectorOfAnyRank:$source)>,
- Results<(outs AnyVectorOfAnyRank:$result)>{
+ Arguments<(ins AnyVectorOfNonI0Elem:$source)>,
+ Results<(outs AnyVectorOfNonI0Elem:$result)>{
let summary = "bitcast casts between vectors";
let description = [{
The bitcast operation casts between vectors of the same rank, the minor 1-D
@@ -2938,12 +2949,12 @@ def Vector_ScanOp :
AllTypesMatch<["source", "dest"]>,
AllTypesMatch<["initial_value", "accumulated_value"]> ]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
- AnyVectorOfNonZeroRank:$source,
- AnyVectorOfAnyRank:$initial_value,
+ AnyVectorOfNonZeroRankNonI0Elem:$source,
+ AnyVectorOfNonI0Elem:$initial_value,
I64Attr:$reduction_dim,
BoolAttr:$inclusive)>,
- Results<(outs AnyVectorOfNonZeroRank:$dest,
- AnyVectorOfAnyRank:$accumulated_value)> {
+ Results<(outs AnyVectorOfNonZeroRankNonI0Elem:$dest,
+ AnyVectorOfNonI0Elem:$accumulated_value)> {
let summary = "Scan operation";
let description = [{
Performs an inclusive/exclusive scan on an n-D vector along a single
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index a49880b81e90d..57caaae08462f 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -308,6 +308,17 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
"signless integer or index">;
+// A signless integer type with a non-zero bitwidth (excludes i0).
+def AnyNonZeroBitwidthSignlessInteger : ConfinedType<
+ AnySignlessInteger, [CPred<"!$_self.isInteger(0)">],
+ "non-zero-bitwidth signless integer">;
+
+// A non-zero-bitwidth signless integer or index type.
+def AnyNonZeroBitwidthSignlessIntegerOrIndex : Type<
+ Or<[AnyNonZeroBitwidthSignlessInteger.predicate,
+ Index.predicate]>,
+ "non-zero-bitwidth signless integer or index">;
+
// Floating point types.
// Any float type irrespective of its width.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a2d0eff47ad92..326afcae696cc 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3385,87 +3385,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
return %ext : tensor<i16>
}
-// CHECK-LABEL: @extsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
-// CHECK: return %[[ZERO]] : i16
-func.func @extsi_i0() -> i16 {
- %c0 = arith.constant 0 : i0
- %extsi = arith.extsi %c0 : i0 to i16
- return %extsi : i16
-}
-
-// CHECK-LABEL: @extui_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
-// CHECK: return %[[ZERO]] : i16
-func.func @extui_i0() -> i16 {
- %c0 = arith.constant 0 : i0
- %extui = arith.extui %c0 : i0 to i16
- return %extui : i16
-}
-
-// CHECK-LABEL: @trunc_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @trunc_i0() -> i0 {
- %cFF = arith.constant 0xFF : i8
- %trunc = arith.trunci %cFF : i8 to i0
- return %trunc : i0
-}
-
-// CHECK-LABEL: @shli_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @shli_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %shli = arith.shli %c0, %c0 : i0
- return %shli : i0
-}
-
-// CHECK-LABEL: @shrsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @shrsi_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %shrsi = arith.shrsi %c0, %c0 : i0
- return %shrsi : i0
-}
-
-// CHECK-LABEL: @shrui_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @shrui_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %shrui = arith.shrui %c0, %c0 : i0
- return %shrui : i0
-}
-
-// CHECK-LABEL: @maxsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @maxsi_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %maxsi = arith.maxsi %c0, %c0 : i0
- return %maxsi : i0
-}
-
-// CHECK-LABEL: @minsi_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]] : i0
-func.func @minsi_i0() -> i0 {
- %c0 = arith.constant 0 : i0
- %minsi = arith.minsi %c0, %c0 : i0
- return %minsi : i0
-}
-
-// CHECK-LABEL: @mulsi_extended_i0
-// CHECK: %[[ZERO:.*]] = arith.constant 0 : i0
-// CHECK: return %[[ZERO]], %[[ZERO]] : i0
-func.func @mulsi_extended_i0() -> (i0, i0) {
- %c0 = arith.constant 0 : i0
- %mulsi_extended:2 = arith.mulsi_extended %c0, %c0 : i0
- return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
-}
-
// CHECK-LABEL: @sequences_fastmath_contract
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 70b23e56a712c..0ea614e0d4b97 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -130,14 +130,14 @@ func.func @func_with_ops(f32) {
func.func @func_with_ops(f32) {
^bb0(%a : f32):
- // expected-error at +1 {{'arith.addi' op operand #0 must be signless-integer-like}}
+ // expected-error at +1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
%sf = arith.addi %a, %a : f32
}
// -----
func.func @func_with_ops(%a: f32) {
- // expected-error at +1 {{'arith.addui_extended' op operand #0 must be signless-integer-like}}
+ // expected-error at +1 {{'arith.addui_extended' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
%r:2 = arith.addui_extended %a, %a : f32, i32
return
}
@@ -202,7 +202,7 @@ func.func @func_with_ops(i32, i32) {
// Integer comparisons are not recognized for float types.
func.func @func_with_ops(f32, f32) {
^bb0(%a : f32, %b : f32):
- %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-integer-like, but got 'f32'}}
+ %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got 'f32'}}
}
// -----
@@ -242,7 +242,7 @@ func.func @func_with_ops() {
// -----
func.func @invalid_cmp_shape(%idx : () -> ()) {
- // expected-error at +1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
+ // expected-error at +1 {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got '() -> ()'}}
%cmp = arith.cmpi eq, %idx, %idx : () -> ()
// -----
@@ -352,7 +352,7 @@ func.func @index_cast_index_to_index(%arg0: index) {
// -----
func.func @index_cast_float(%arg0: index, %arg1: f32) {
- // expected-error at +1 {{op result #0 must be signless-integer-like or memref of signless-integer, but got 'f32'}}
+ // expected-error at +1 {{op result #0 must be signless-non-zero-bitwidth-integer-like or memref of signless-integer, but got 'f32'}}
%0 = arith.index_cast %arg0 : index to f32
return
}
@@ -360,7 +360,7 @@ func.func @index_cast_float(%arg0: index, %arg1: f32) {
// -----
func.func @index_cast_float_to_index(%arg0: f32) {
- // expected-error at +1 {{op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32'}}
+ // expected-error at +1 {{op operand #0 must be signless-non-zero-bitwidth-integer-like or memref of signless-integer, but got 'f32'}}
%0 = arith.index_cast %arg0 : f32 to index
return
}
@@ -857,7 +857,7 @@ func.func @select_tensor_encoding(
// -----
func.func @bitcast_index_0(%arg0 : i64) -> index {
- // expected-error @+1 {{'arith.bitcast' op result #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
+ // expected-error @+1 {{'arith.bitcast' op result #0 must be non-zero-bitwidth-signless-integer-or-float-like or memref of non-zero-bitwidth signless integer or float, but got 'index'}}
%0 = arith.bitcast %arg0 : i64 to index
return %0 : index
}
@@ -865,7 +865,7 @@ func.func @bitcast_index_0(%arg0 : i64) -> index {
// -----
func.func @bitcast_index_1(%arg0 : index) -> i64 {
- // expected-error @+1 {{'arith.bitcast' op operand #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
+ // expected-error @+1 {{'arith.bitcast' op operand #0 must be non-zero-bitwidth-signless-integer-or-float-like or memref of non-zero-bitwidth signless integer or float, but got 'index'}}
%0 = arith.bitcast %arg0 : index to i64
return %0 : i64
}
@@ -877,3 +877,142 @@ func.func @select_vector_condition_scalar_operands(%arg0: vector<1xi1>, %arg1: i
%0 = arith.select %arg0, %arg1, %arg1 : vector<1xi1>, i32
return
}
+
+// -----
+
+// Verify that i0 (zero-bitwidth integer) is rejected by arith integer ops.
+
+func.func @addi_i0(%a: i0, %b: i0) -> i0 {
+ // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+ %0 = arith.addi %a, %b : i0
+ return %0 : i0
+}
+
+// -----
+
+func.func @addi_vector_i0(%a: vector<4xi0>, %b: vector<4xi0>) -> vector<4xi0> {
+ // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'vector<4xi0>'}}
+ %0 = arith.addi %a, %b : vector<4xi0>
+ return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @trunci_to_i0(%a: i32) -> i0 {
+ // expected-error @+1 {{'arith.trunci' op result #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+ %0 = arith.trunci %a : i32 to i0
+ return %0 : i0
+}
+
+// -----
+
+func.func @extsi_from_i0(%a: i0) -> i16 {
+ // expected-error @+1 {{'arith.extsi' op operand #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+ %0 = arith.extsi %a : i0 to i16
+ return %0 : i16
+}
+
+// -----
+
+func.func @extui_from_i0(%a: i0) -> i16 {
+ // expected-error @+1 {{'arith.extui' op operand #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+ %0 = arith.extui %a : i0 to i16
+ return %0 : i16
+}
+
+// -----
+
+func.func @cmpi_i0(%a: i0, %b: i0) -> i1 {
+ // expected-error @+1 {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+ %0 = arith.cmpi eq, %a, %b : i0
+ return %0 : i1
+}
+
+// -----
+
+// Arith_TotalIntBinaryOp (andi, ori, xori, maxsi, maxui, minsi, minui,
+// floordivsi, remui, remsi).
+func.func @andi_i0(%a: i0, %b: i0) -> i0 {
+ // expected-error @+1 {{'arith.andi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+ %0 = arith.andi %a, %b : i0
+ return %0 : i0
+}
+
+// -----
+
+// Arith_IntBinaryOpWithExactFlag (divsi, divui, shrsi, shrui).
+func.func @divsi_i0(%a: i0, %b: i0) -> i0 {
+ // expected-error @+1 {{'arith.divsi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+ %0 = arith.divsi %a, %b : i0
+ return %0 : i0
+}
+
+// -----
+
+// Arith_IntBinaryOp (ceildivsi, ceildivui).
+func.func @ceildivsi_i0(%a: i0, %b: i0) -> i0 {
+ // expected-error @+1 {{'arith.ceildivsi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+ %0 = arith.ceildivsi %a, %b : i0
+ return %0 : i0
+}
+
+// -----
+
+func.func @mulsi_extended_i0(%a: i0, %b: i0) -> (i0, i0) {
+ // expected-error @+1 {{'arith.mulsi_extended' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+ %0:2 = arith.mulsi_extended %a, %b : i0
+ return %0#0, %0#1 : i0, i0
+}
+
+// -----
+
+func.func @mului_extended_i0(%a: i0, %b: i0) -> (i0, i0) {
+ // expected-error @+1 {{'arith.mului_extended' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+ %0:2 = arith.mului_extended %a, %b : i0
+ return %0#0, %0#1 : i0, i0
+}
+
+// -----
+
+// IToFCastOp (sitofp, uitofp) — cast FROM i0.
+func.func @sitofp_i0(%a: i0) -> f32 {
+ // expected-error @+1 {{'arith.sitofp' op operand #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+ %0 = arith.sitofp %a : i0 to f32
+ return %0 : f32
+}
+
+// -----
+
+// FToICastOp (fptosi, fptoui) — cast TO i0.
+func.func @fptosi_i0(%a: f32) -> i0 {
+ // expected-error @+1 {{'arith.fptosi' op result #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+ %0 = arith.fptosi %a : f32 to i0
+ return %0 : i0
+}
+
+// -----
+
+// arith.bitcast rejects i0 source and result.
+func.func @bitcast_i0(%a: i0) -> i0 {
+ // expected-error @+1 {{'arith.bitcast' op operand #0 must be non-zero-bitwidth-signless-integer-or-float-like or memref of non-zero-bitwidth signless integer or float, but got 'i0'}}
+ %0 = arith.bitcast %a : i0 to i0
+ return %0 : i0
+}
+
+// -----
+
+// arith.index_cast rejects i0.
+func.func @index_cast_i0(%a: i0) -> index {
+ // expected-error @+1 {{'arith.index_cast' op operand #0 must be signless-non-zero-bitwidth-integer-like or memref of signless-integer, but got 'i0'}}
+ %0 = arith.index_cast %a : i0 to index
+ return %0 : index
+}
+
+// -----
+
+// arith.index_castui rejects i0.
+func.func @index_castui_i0(%a: i0) -> index {
+ // expected-error @+1 {{'arith.index_castui' op operand #0 must be signless-non-zero-bitwidth-integer-like or memref of signless-integer, but got 'i0'}}
+ %0 = arith.index_castui %a : i0 to index
+ return %0 : index
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 9b2a2bdc19a57..333d342d76103 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2064,3 +2064,59 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
return
}
+
+// -----
+
+// Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
+func.func @bitcast_i0(%a: vector<4xi0>) -> vector<4xi0> {
+ // expected-error @+1 {{'vector.bitcast' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+ %0 = vector.bitcast %a : vector<4xi0> to vector<4xi0>
+ return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @reduction_i0(%a: vector<4xi0>) -> i0 {
+ // expected-error @+1 {{'vector.reduction' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+ %0 = vector.reduction <add>, %a : vector<4xi0> into i0
+ return %0 : i0
+}
+
+// -----
+
+func.func @multi_reduction_i0(%a: vector<4x8xi0>, %acc: vector<4xi0>) -> vector<4xi0> {
+ // expected-error @+1 {{'vector.multi_reduction' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4x8xi0>'}}
+ %0 = vector.multi_reduction <add>, %a, %acc [1] : vector<4x8xi0> to vector<4xi0>
+ return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @contract_i0(%lhs: vector<4xi0>, %rhs: vector<4xi0>, %acc: i0) -> i0 {
+ // expected-error @+1 {{'vector.contract' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> ()>],
+ iterator_types = ["reduction"],
+ kind = #vector.kind<add>
+ } %lhs, %rhs, %acc : vector<4xi0>, vector<4xi0> into i0
+ return %0 : i0
+}
+
+// -----
+
+func.func @outerproduct_i0(%lhs: vector<4xi0>, %rhs: i0) -> vector<4xi0> {
+ // expected-error @+1 {{'vector.outerproduct' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+ %0 = vector.outerproduct %lhs, %rhs : vector<4xi0>, i0
+ return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @scan_i0(%a: vector<4xi0>, %init: vector<1xi0>) -> (vector<4xi0>, vector<1xi0>) {
+ // expected-error @+1 {{'vector.scan' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+ %0:2 = vector.scan <add>, %a, %init {inclusive = true, reduction_dim = 0 : i64} :
+ vector<4xi0>, vector<1xi0>
+ return %0#0, %0#1 : vector<4xi0>, vector<1xi0>
+}
More information about the Mlir-commits
mailing list