[Mlir-commits] [mlir] [MLIR][Arith][Vector] Reject i0 integer type in arith and vector ops (PR #183589)

Mehdi Amini llvmlistbot at llvm.org
Wed Mar 4 03:21:41 PST 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/183589

>From 85502384562a7ab17c4ef3603100bbdbb1658d94 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Feb 2026 10:21:20 -0800
Subject: [PATCH] [MLIR][Arith][Vector] Reject i0 integer type in arith and
 vector ops

Add ODS type constraints that exclude zero-bitwidth integers (i0) from
operations in the arith and vector dialects. i0 has no meaningful
arithmetic representation, and operations on it can trigger undefined
behavior (e.g. bitwidth calculations that assume a non-zero width).

Changes:
- Add `AnyNonZeroBitwidthSignlessInteger` (as a `ConfinedType` over
  `AnySignlessInteger`) and `AnyNonZeroBitwidthSignlessIntegerOrIndex`
  to CommonTypeConstraints.td.
- Introduce `Arith_SignlessIntegerOrIndexLike` in ArithOps.td that wraps
  `AnyNonZeroBitwidthSignlessIntegerOrIndex` via
  `TypeOrValueSemanticsContainer`, and update `SignlessFixedWidthIntegerLike`
  to use `AnyNonZeroBitwidthSignlessInteger`.  Replace all uses of the
  shared `SignlessIntegerOrIndexLike` in ArithOps.td with the new
  dialect-local constraint.
- Update `IndexCastTypeConstraint` to use `Arith_SignlessIntegerOrIndexLike`.
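- Update `BitcastTypeConstraint` to exclude i0 by composing the already-
  defined `SignlessFixedWidthIntegerLike` and `FloatLike` constraints,
  keeping the definition compact (3 alternatives instead of the 7 needed to
  spell out the scalar, vector, and tensor cases of each element type plus
  the memref case).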
- Add `AnyVectorOfNonI0Elem` and `AnyVectorOfNonZeroRankNonI0Elem` in
  VectorOps.td and apply them to `vector.contract`, `vector.reduction`,
  `vector.multi_reduction`, `vector.outerproduct`, `vector.bitcast`, and
  `vector.scan`.
- Update arith/invalid.mlir with explicit i0 rejection tests covering all
  integer op families (binary ops, cast ops, extended-multiply ops, cmpi,
  bitcast, index_cast, index_castui) in scalar form, plus a vector<N> form
  for arith.addi. Existing expected-error strings are updated to match the
  new constraint summaries.
- Update vector/invalid.mlir with i0 rejection tests for all covered ops.
- Remove the now-invalid i0 canonicalization tests from
  arith/canonicalize.mlir.
---
 .../mlir/Dialect/Arith/IR/ArithBase.td        |   4 +
 .../include/mlir/Dialect/Arith/IR/ArithOps.td |  58 ++++---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  35 ++--
 mlir/include/mlir/IR/CommonTypeConstraints.td |  11 ++
 mlir/test/Dialect/Arith/canonicalize.mlir     |  81 ---------
 mlir/test/Dialect/Arith/invalid.mlir          | 155 +++++++++++++++++-
 mlir/test/Dialect/Vector/invalid.mlir         |  56 +++++++
 7 files changed, 275 insertions(+), 125 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index a3e9fc3c57bb1..d4649c45c90c2 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -25,6 +25,10 @@ def Arith_Dialect : Dialect {
     propagate poison values, i.e., if any of its inputs are poison, then the
     output is poison. Unless otherwise stated, operations applied to `vector`
     and `tensor` values propagates poison elementwise.
+
+    Manipulating values of type `i0` is not currently supported in this
+    dialect and is considered invalid. This may change in the future if
+    motivating use cases are presented.
   }];
 
   let hasConstantMaterializer = 1;
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 4ee9187c7f224..ab85574069687 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -21,6 +21,12 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/EnumAttr.td"
 
+// Type constraint for signless-integer-or-index-like types that additionally
+// excludes i0 (zero-bitwidth integers), used for arith integer operations.
+def Arith_SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
+    AnyNonZeroBitwidthSignlessIntegerOrIndex,
+    "signless-non-zero-bitwidth-integer-like">;
+
 // Base class for Arith dialect ops. Ops in this dialect have no memory
 // effects and can be applied element-wise to vectors and tensors.
 class Arith_Op<string mnemonic, list<Trait> traits = []> :
@@ -51,8 +57,8 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
 class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic, traits #
       [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
-    Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>,
-    Results<(outs SignlessIntegerOrIndexLike:$result)>;
+    Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs)>,
+    Results<(outs Arith_SignlessIntegerOrIndexLike:$result)>;
 
 // Base class for integer binary operations without undefined behavior.
 class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
@@ -110,10 +116,12 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
 
 // Casts do not accept indices. Type constraint for signless-integer-like types
 // excluding indices: signless integers, vectors or tensors thereof.
+// i0 (zero-bitwidth) integers are excluded as they have no meaningful
+// representation in cast operations.
 def SignlessFixedWidthIntegerLike : TypeConstraint<Or<[
-        AnySignlessInteger.predicate,
-        VectorOfAnyRankOf<[AnySignlessInteger]>.predicate,
-        TensorOf<[AnySignlessInteger]>.predicate]>,
+        AnyNonZeroBitwidthSignlessInteger.predicate,
+        VectorOfAnyRankOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate,
+        TensorOf<[AnyNonZeroBitwidthSignlessInteger]>.predicate]>,
     "signless-fixed-width-integer-like">;
 
 // Cast from an integer type to another integer type.
@@ -148,11 +156,11 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
     Arith_BinaryOp<mnemonic, traits #
       [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
        DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
-    Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
+    Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs,
       DefaultValuedAttr<
         Arith_IntegerOverflowAttr,
         "::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
-    Results<(outs SignlessIntegerOrIndexLike:$result)> {
+    Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
 
   let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
                           attr-dict `:` type($result) }];
@@ -162,10 +170,10 @@ class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic, traits #
       [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
        DeclareOpInterfaceMethods<ArithExactFlagInterface>]>,
-    Arguments<(ins SignlessIntegerOrIndexLike:$lhs,
-               SignlessIntegerOrIndexLike:$rhs,
+    Arguments<(ins Arith_SignlessIntegerOrIndexLike:$lhs,
+               Arith_SignlessIntegerOrIndexLike:$rhs,
                UnitAttr:$isExact)>,
-    Results<(outs SignlessIntegerOrIndexLike:$result)> {
+    Results<(outs Arith_SignlessIntegerOrIndexLike:$result)> {
 
   let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)?
                           attr-dict `:` type($result) }];
@@ -294,8 +302,8 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
-  let results = (outs SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
+  let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs Arith_SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
   let assemblyFormat = [{
     $lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
   }];
@@ -435,8 +443,8 @@ def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
-  let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+  let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
 
@@ -478,8 +486,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
     ```
   }];
 
-  let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
-  let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
+  let arguments = (ins Arith_SignlessIntegerOrIndexLike:$lhs, Arith_SignlessIntegerOrIndexLike:$rhs);
+  let results = (outs Arith_SignlessIntegerOrIndexLike:$low, Arith_SignlessIntegerOrIndexLike:$high);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
 
@@ -1574,9 +1582,9 @@ def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
 
 // Index cast can convert between memrefs of signless integers and indices too.
 def IndexCastTypeConstraint : TypeConstraint<Or<[
-        SignlessIntegerOrIndexLike.predicate,
-        MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
-    "signless-integer-like or memref of signless-integer">;
+        Arith_SignlessIntegerOrIndexLike.predicate,
+        MemRefOf<[AnyNonZeroBitwidthSignlessInteger, Index]>.predicate]>,
+    "signless-non-zero-bitwidth-integer-like or memref of signless-integer">;
 
 def Arith_IndexCastOp
   : Arith_CastOp<"index_cast", IndexCastTypeConstraint, IndexCastTypeConstraint,
@@ -1664,10 +1672,12 @@ def Arith_IndexCastUIOp
 //===----------------------------------------------------------------------===//
 
 // Bitcast can convert between memrefs of signless integers and floats.
+// i0 (zero-bitwidth) integers are excluded: a 0-bit bitcast is meaningless.
 def BitcastTypeConstraint : TypeConstraint<Or<[
-        SignlessIntegerOrFloatLike.predicate,
-        MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>,
-    "signless-integer-or-float-like or memref of signless-integer or float">;
+        SignlessFixedWidthIntegerLike.predicate,
+        FloatLike.predicate,
+        MemRefOf<[AnyNonZeroBitwidthSignlessInteger, AnyFloat]>.predicate]>,
+    "non-zero-bitwidth-signless-integer-or-float-like or memref of non-zero-bitwidth signless integer or float">;
 
 def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
                                               BitcastTypeConstraint> {
@@ -1766,8 +1776,8 @@ def Arith_CmpIOp : Arith_CompareOp<"cmpi",
   }];
 
   let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
-                       SignlessIntegerOrIndexLike:$lhs,
-                       SignlessIntegerOrIndexLike:$rhs);
+                       Arith_SignlessIntegerOrIndexLike:$lhs,
+                       Arith_SignlessIntegerOrIndexLike:$rhs);
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ddb04b6bbe40d..43ad435ccf1c1 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -32,6 +32,17 @@ include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/BuiltinAttributes.td"
 include "mlir/IR/EnumAttr.td"
 
+// Type constraint helpers for the vector dialect.
+
+// Element constraint rejecting i0 (zero-bitwidth integer), shared by the
+// vector constraints below. Any other element type is accepted.
+def Vector_NonI0ElemType : Type<CPred<"!$_self.isInteger(0)">,
+                                "non-zero-bitwidth type">;
+
+// Vectors of any rank / of rank >= 1 whose element type is not i0.
+def AnyVectorOfNonI0Elem : VectorOfAnyRankOf<[Vector_NonI0ElemType]>;
+def AnyVectorOfNonZeroRankNonI0Elem : VectorOfNonZeroRankOf<[Vector_NonI0ElemType]>;
+
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
 def Vector_ContractionOp :
@@ -44,7 +55,7 @@ def Vector_ContractionOp :
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]>,
-    Arguments<(ins AnyVectorOfNonZeroRank:$lhs, AnyVectorOfNonZeroRank:$rhs, AnyType:$acc,
+    Arguments<(ins AnyVectorOfNonZeroRankNonI0Elem:$lhs, AnyVectorOfNonZeroRankNonI0Elem:$rhs, AnyType:$acc,
                ArrayAttr:$indexing_maps,
                Vector_IteratorTypeArrayAttr:$iterator_types,
                DefaultValuedAttr<Vector_CombiningKindAttr,
@@ -239,7 +250,7 @@ def Vector_ReductionOp :
      DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
-               AnyVectorOfAnyRank:$vector,
+               AnyVectorOfNonI0Elem:$vector,
                Optional<AnyType>:$acc,
                DefaultValuedAttr<
                  Arith_FastMathAttr,
@@ -299,7 +310,7 @@ def Vector_MultiDimReductionOp :
      DeclareOpInterfaceMethods<VectorUnrollOpInterface,
                                ["getShapeForUnroll"]>]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
-                   AnyVectorOfNonZeroRank:$source,
+                   AnyVectorOfNonZeroRankNonI0Elem:$source,
                    AnyType:$acc,
                    DenseI64ArrayAttr:$reduction_dims)>,
     Results<(outs AnyType:$dest)> {
@@ -1094,10 +1105,10 @@ def Vector_OuterProductOp :
     PredOpTrait<"rhs operand and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 1>>,
     DeclareOpInterfaceMethods<MaskableOpInterface>]>,
-    Arguments<(ins AnyVectorOfNonZeroRank:$lhs, AnyType:$rhs,
-               Optional<AnyVectorOfNonZeroRank>:$acc,
+    Arguments<(ins AnyVectorOfNonZeroRankNonI0Elem:$lhs, AnyType:$rhs,
+               Optional<AnyVectorOfNonZeroRankNonI0Elem>:$acc,
                DefaultValuedAttr<Vector_CombiningKindAttr, "CombiningKind::ADD">:$kind)>,
-    Results<(outs AnyVectorOfNonZeroRank)> {
+    Results<(outs AnyVectorOfNonZeroRankNonI0Elem)> {
   let summary = "vector outerproduct with optional fused add";
   let description = [{
     Takes 2 1-D vectors and returns the 2-D vector containing the outer-product,
@@ -2454,8 +2465,8 @@ def Vector_ShapeCastOp :
 
 def Vector_BitCastOp :
   Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
-    Arguments<(ins AnyVectorOfAnyRank:$source)>,
-    Results<(outs AnyVectorOfAnyRank:$result)>{
+    Arguments<(ins AnyVectorOfNonI0Elem:$source)>,
+    Results<(outs AnyVectorOfNonI0Elem:$result)>{
   let summary = "bitcast casts between vectors";
   let description = [{
     The bitcast operation casts between vectors of the same rank, the minor 1-D
@@ -2938,12 +2949,12 @@ def Vector_ScanOp :
     AllTypesMatch<["source", "dest"]>,
     AllTypesMatch<["initial_value", "accumulated_value"]> ]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
-                   AnyVectorOfNonZeroRank:$source,
-                   AnyVectorOfAnyRank:$initial_value,
+                   AnyVectorOfNonZeroRankNonI0Elem:$source,
+                   AnyVectorOfNonI0Elem:$initial_value,
                    I64Attr:$reduction_dim,
                    BoolAttr:$inclusive)>,
-    Results<(outs AnyVectorOfNonZeroRank:$dest,
-                  AnyVectorOfAnyRank:$accumulated_value)> {
+    Results<(outs AnyVectorOfNonZeroRankNonI0Elem:$dest,
+                  AnyVectorOfNonI0Elem:$accumulated_value)> {
   let summary = "Scan operation";
   let description = [{
     Performs an inclusive/exclusive scan on an n-D vector along a single
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index a49880b81e90d..57caaae08462f 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -308,6 +308,17 @@ def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
 def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
                                      "signless integer or index">;
 
+// A signless integer type with a non-zero bitwidth (excludes i0).
+def AnyNonZeroBitwidthSignlessInteger : ConfinedType<
+    AnySignlessInteger, [CPred<"!$_self.isInteger(0)">],
+    "non-zero-bitwidth signless integer">;
+
+// A non-zero-bitwidth signless integer or index type.
+def AnyNonZeroBitwidthSignlessIntegerOrIndex : Type<
+    Or<[AnyNonZeroBitwidthSignlessInteger.predicate,
+        Index.predicate]>,
+    "non-zero-bitwidth signless integer or index">;
+
 // Floating point types.
 
 // Any float type irrespective of its width.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a2d0eff47ad92..326afcae696cc 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3385,87 +3385,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
   return %ext : tensor<i16>
 }
 
-// CHECK-LABEL: @extsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i16
-//       CHECK:   return %[[ZERO]] : i16
-func.func @extsi_i0() -> i16 {
-  %c0 = arith.constant 0 : i0
-  %extsi = arith.extsi %c0 : i0 to i16
-  return %extsi : i16
-}
-
-// CHECK-LABEL: @extui_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i16
-//       CHECK:   return %[[ZERO]] : i16
-func.func @extui_i0() -> i16 {
-  %c0 = arith.constant 0 : i0
-  %extui = arith.extui %c0 : i0 to i16
-  return %extui : i16
-}
-
-// CHECK-LABEL: @trunc_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @trunc_i0() -> i0 {
-  %cFF = arith.constant 0xFF : i8
-  %trunc = arith.trunci %cFF : i8 to i0
-  return %trunc : i0
-}
-
-// CHECK-LABEL: @shli_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @shli_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %shli = arith.shli %c0, %c0 : i0
-  return %shli : i0
-}
-
-// CHECK-LABEL: @shrsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @shrsi_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %shrsi = arith.shrsi %c0, %c0 : i0
-  return %shrsi : i0
-}
-
-// CHECK-LABEL: @shrui_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @shrui_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %shrui = arith.shrui %c0, %c0 : i0
-  return %shrui : i0
-}
-
-// CHECK-LABEL: @maxsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @maxsi_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %maxsi = arith.maxsi %c0, %c0 : i0
-  return %maxsi : i0
-}
-
-// CHECK-LABEL: @minsi_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]] : i0
-func.func @minsi_i0() -> i0 {
-  %c0 = arith.constant 0 : i0
-  %minsi = arith.minsi %c0, %c0 : i0
-  return %minsi : i0
-}
-
-// CHECK-LABEL: @mulsi_extended_i0
-//       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i0
-//       CHECK:   return %[[ZERO]], %[[ZERO]] : i0
-func.func @mulsi_extended_i0() -> (i0, i0) {
-  %c0 = arith.constant 0 : i0
-  %mulsi_extended:2 = arith.mulsi_extended %c0, %c0 : i0
-  return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
-}
-
 // CHECK-LABEL: @sequences_fastmath_contract
 // CHECK-SAME: ([[ARG0:%.+]]: bf16)
 // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 70b23e56a712c..0ea614e0d4b97 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -130,14 +130,14 @@ func.func @func_with_ops(f32) {
 
 func.func @func_with_ops(f32) {
 ^bb0(%a : f32):
-  // expected-error at +1 {{'arith.addi' op operand #0 must be signless-integer-like}}
+  // expected-error at +1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
   %sf = arith.addi %a, %a : f32
 }
 
 // -----
 
 func.func @func_with_ops(%a: f32) {
-  // expected-error at +1 {{'arith.addui_extended' op operand #0 must be signless-integer-like}}
+  // expected-error at +1 {{'arith.addui_extended' op operand #0 must be signless-non-zero-bitwidth-integer-like}}
   %r:2 = arith.addui_extended %a, %a : f32, i32
   return
 }
@@ -202,7 +202,7 @@ func.func @func_with_ops(i32, i32) {
 // Integer comparisons are not recognized for float types.
 func.func @func_with_ops(f32, f32) {
 ^bb0(%a : f32, %b : f32):
-  %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-integer-like, but got 'f32'}}
+  %r = arith.cmpi eq, %a, %b : f32 // expected-error {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got 'f32'}}
 }
 
 // -----
@@ -242,7 +242,7 @@ func.func @func_with_ops() {
 // -----
 
 func.func @invalid_cmp_shape(%idx : () -> ()) {
-  // expected-error at +1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
+  // expected-error at +1 {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got '() -> ()'}}
   %cmp = arith.cmpi eq, %idx, %idx : () -> ()
 
 // -----
@@ -352,7 +352,7 @@ func.func @index_cast_index_to_index(%arg0: index) {
 // -----
 
 func.func @index_cast_float(%arg0: index, %arg1: f32) {
-  // expected-error at +1 {{op result #0 must be signless-integer-like or memref of signless-integer, but got 'f32'}}
+  // expected-error at +1 {{op result #0 must be signless-non-zero-bitwidth-integer-like or memref of signless-integer, but got 'f32'}}
   %0 = arith.index_cast %arg0 : index to f32
   return
 }
@@ -360,7 +360,7 @@ func.func @index_cast_float(%arg0: index, %arg1: f32) {
 // -----
 
 func.func @index_cast_float_to_index(%arg0: f32) {
-  // expected-error at +1 {{op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32'}}
+  // expected-error at +1 {{op operand #0 must be signless-non-zero-bitwidth-integer-like or memref of signless-integer, but got 'f32'}}
   %0 = arith.index_cast %arg0 : f32 to index
   return
 }
@@ -857,7 +857,7 @@ func.func @select_tensor_encoding(
 // -----
 
 func.func @bitcast_index_0(%arg0 : i64) -> index {
-  // expected-error @+1 {{'arith.bitcast' op result #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
+  // expected-error @+1 {{'arith.bitcast' op result #0 must be non-zero-bitwidth-signless-integer-or-float-like or memref of non-zero-bitwidth signless integer or float, but got 'index'}}
   %0 = arith.bitcast %arg0 : i64 to index
   return %0 : index
 }
@@ -865,7 +865,7 @@ func.func @bitcast_index_0(%arg0 : i64) -> index {
 // -----
 
 func.func @bitcast_index_1(%arg0 : index) -> i64 {
-  // expected-error @+1 {{'arith.bitcast' op operand #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
+  // expected-error @+1 {{'arith.bitcast' op operand #0 must be non-zero-bitwidth-signless-integer-or-float-like or memref of non-zero-bitwidth signless integer or float, but got 'index'}}
   %0 = arith.bitcast %arg0 : index to i64
   return %0 : i64
 }
@@ -877,3 +877,142 @@ func.func @select_vector_condition_scalar_operands(%arg0: vector<1xi1>, %arg1: i
   %0 = arith.select %arg0, %arg1, %arg1 : vector<1xi1>, i32
   return
 }
+
+// -----
+
+// Verify that i0 (zero-bitwidth integer) is rejected by arith integer ops.
+
+func.func @addi_i0(%a: i0, %b: i0) -> i0 {
+  // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+  %0 = arith.addi %a, %b : i0
+  return %0 : i0
+}
+
+// -----
+
+func.func @addi_vector_i0(%a: vector<4xi0>, %b: vector<4xi0>) -> vector<4xi0> {
+  // expected-error @+1 {{'arith.addi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'vector<4xi0>'}}
+  %0 = arith.addi %a, %b : vector<4xi0>
+  return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @trunci_to_i0(%a: i32) -> i0 {
+  // expected-error @+1 {{'arith.trunci' op result #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+  %0 = arith.trunci %a : i32 to i0
+  return %0 : i0
+}
+
+// -----
+
+func.func @extsi_from_i0(%a: i0) -> i16 {
+  // expected-error @+1 {{'arith.extsi' op operand #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+  %0 = arith.extsi %a : i0 to i16
+  return %0 : i16
+}
+
+// -----
+
+func.func @extui_from_i0(%a: i0) -> i16 {
+  // expected-error @+1 {{'arith.extui' op operand #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+  %0 = arith.extui %a : i0 to i16
+  return %0 : i16
+}
+
+// -----
+
+func.func @cmpi_i0(%a: i0, %b: i0) -> i1 {
+  // expected-error @+1 {{'lhs' must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+  %0 = arith.cmpi eq, %a, %b : i0
+  return %0 : i1
+}
+
+// -----
+
+// Arith_TotalIntBinaryOp (andi, ori, xori, maxsi, maxui, minsi, minui,
+// floordivsi, remui, remsi).
+func.func @andi_i0(%a: i0, %b: i0) -> i0 {
+  // expected-error @+1 {{'arith.andi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+  %0 = arith.andi %a, %b : i0
+  return %0 : i0
+}
+
+// -----
+
+// Arith_IntBinaryOpWithExactFlag (divsi, divui, shrsi, shrui).
+func.func @divsi_i0(%a: i0, %b: i0) -> i0 {
+  // expected-error @+1 {{'arith.divsi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+  %0 = arith.divsi %a, %b : i0
+  return %0 : i0
+}
+
+// -----
+
+// Arith_IntBinaryOp (ceildivsi, ceildivui).
+func.func @ceildivsi_i0(%a: i0, %b: i0) -> i0 {
+  // expected-error @+1 {{'arith.ceildivsi' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+  %0 = arith.ceildivsi %a, %b : i0
+  return %0 : i0
+}
+
+// -----
+
+func.func @mulsi_extended_i0(%a: i0, %b: i0) -> (i0, i0) {
+  // expected-error @+1 {{'arith.mulsi_extended' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+  %0:2 = arith.mulsi_extended %a, %b : i0
+  return %0#0, %0#1 : i0, i0
+}
+
+// -----
+
+func.func @mului_extended_i0(%a: i0, %b: i0) -> (i0, i0) {
+  // expected-error @+1 {{'arith.mului_extended' op operand #0 must be signless-non-zero-bitwidth-integer-like, but got 'i0'}}
+  %0:2 = arith.mului_extended %a, %b : i0
+  return %0#0, %0#1 : i0, i0
+}
+
+// -----
+
+// IToFCastOp (sitofp, uitofp) — cast FROM i0.
+func.func @sitofp_i0(%a: i0) -> f32 {
+  // expected-error @+1 {{'arith.sitofp' op operand #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+  %0 = arith.sitofp %a : i0 to f32
+  return %0 : f32
+}
+
+// -----
+
+// FToICastOp (fptosi, fptoui) — cast TO i0.
+func.func @fptosi_i0(%a: f32) -> i0 {
+  // expected-error @+1 {{'arith.fptosi' op result #0 must be signless-fixed-width-integer-like, but got 'i0'}}
+  %0 = arith.fptosi %a : f32 to i0
+  return %0 : i0
+}
+
+// -----
+
+// arith.bitcast rejects i0 source and result.
+func.func @bitcast_i0(%a: i0) -> i0 {
+  // expected-error @+1 {{'arith.bitcast' op operand #0 must be non-zero-bitwidth-signless-integer-or-float-like or memref of non-zero-bitwidth signless integer or float, but got 'i0'}}
+  %0 = arith.bitcast %a : i0 to i0
+  return %0 : i0
+}
+
+// -----
+
+// arith.index_cast rejects i0.
+func.func @index_cast_i0(%a: i0) -> index {
+  // expected-error @+1 {{'arith.index_cast' op operand #0 must be signless-non-zero-bitwidth-integer-like or memref of signless-integer, but got 'i0'}}
+  %0 = arith.index_cast %a : i0 to index
+  return %0 : index
+}
+
+// -----
+
+// arith.index_castui rejects i0.
+func.func @index_castui_i0(%a: i0) -> index {
+  // expected-error @+1 {{'arith.index_castui' op operand #0 must be signless-non-zero-bitwidth-integer-like or memref of signless-integer, but got 'i0'}}
+  %0 = arith.index_castui %a : i0 to index
+  return %0 : index
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 9b2a2bdc19a57..333d342d76103 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2064,3 +2064,59 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
   vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
   return
 }
+
+// -----
+
+// Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
+func.func @bitcast_i0(%a: vector<4xi0>) -> vector<4xi0> {
+  // expected-error @+1 {{'vector.bitcast' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+  %0 = vector.bitcast %a : vector<4xi0> to vector<4xi0>
+  return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @reduction_i0(%a: vector<4xi0>) -> i0 {
+  // expected-error @+1 {{'vector.reduction' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+  %0 = vector.reduction <add>, %a : vector<4xi0> into i0
+  return %0 : i0
+}
+
+// -----
+
+func.func @multi_reduction_i0(%a: vector<4x8xi0>, %acc: vector<4xi0>) -> vector<4xi0> {
+  // expected-error @+1 {{'vector.multi_reduction' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4x8xi0>'}}
+  %0 = vector.multi_reduction <add>, %a, %acc [1] : vector<4x8xi0> to vector<4xi0>
+  return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @contract_i0(%lhs: vector<4xi0>, %rhs: vector<4xi0>, %acc: i0) -> i0 {
+  // expected-error @+1 {{'vector.contract' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+  %0 = vector.contract {
+    indexing_maps = [affine_map<(d0) -> (d0)>,
+                     affine_map<(d0) -> (d0)>,
+                     affine_map<(d0) -> ()>],
+    iterator_types = ["reduction"],
+    kind = #vector.kind<add>
+  } %lhs, %rhs, %acc : vector<4xi0>, vector<4xi0> into i0
+  return %0 : i0
+}
+
+// -----
+
+func.func @outerproduct_i0(%lhs: vector<4xi0>, %rhs: i0) -> vector<4xi0> {
+  // expected-error @+1 {{'vector.outerproduct' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+  %0 = vector.outerproduct %lhs, %rhs : vector<4xi0>, i0
+  return %0 : vector<4xi0>
+}
+
+// -----
+
+func.func @scan_i0(%a: vector<4xi0>, %init: vector<i0>) -> (vector<4xi0>, vector<i0>) {
+  // expected-error @+1 {{'vector.scan' op operand #0 must be vector of non-zero-bitwidth type values, but got 'vector<4xi0>'}}
+  %0:2 = vector.scan <add>, %a, %init {inclusive = true, reduction_dim = 0 : i64} :
+    vector<4xi0>, vector<i0>
+  return %0#0, %0#1 : vector<4xi0>, vector<i0>
+}



More information about the Mlir-commits mailing list