[Mlir-commits] [mlir] [mlir][test] Extend `InferIntRangeInterface` test Ops to arbitrary ints (PR #91850)

Felix Schneider llvmlistbot at llvm.org
Sat May 11 04:02:02 PDT 2024


https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/91850

This PR prepares for some extensions to the `InferIntRangeInterface` around the `nsw` and `nuw` flags supported in the `arith` dialect and LLVM.

We provide some common inference logic for `index` and `arith` in `InferIntRangeCommon.h`, but our Test Ops are currently fixed to `Index` Types. Since we test the range inference for arith Ops, especially around the overflow behaviour, it's handy to have native support for the typical integer types in the test Ops.

This patch
1. Changes the Attributes of `test.with_bounds` ops from `Index` to `APInt`, which matches the internal representation in `ConstantIntRanges`.
2. Allows the use of `AnyInteger` in addition to `Index` for the operands and results of the test Ops. This now requires explicit specification of the type in the IR, where before `Index` was implicit.
3. Requires bounds Attrs to be specified in the precision of the SSA value, eliminating any implicit truncation or extension. (*Could this lead to problems?*)
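
To illustrate point 3, here is a minimal standalone C++ sketch (not MLIR code) of why the precision of a bound matters. All names in it (`I8Bounds`, `fullRangeI8`, `asPrintedSigned`, `zext`, `sext`) are hypothetical helpers invented for this example. It assumes, as the new `test_i8_bounds` CHECK line suggests, that a signless `i8` attribute is displayed through its signed interpretation, so a bound written as `255 : i8` shows up as `-1 : i8`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper type (not part of MLIR): the four bounds of an
// 8-bit value, each kept at the 8-bit precision of the value itself.
struct I8Bounds {
  std::uint8_t umin, umax; // unsigned bounds
  std::int8_t smin, smax;  // signed bounds
};

// Bounds of an i8 value about which nothing is known.
constexpr I8Bounds fullRangeI8() {
  return {0, UINT8_MAX, INT8_MIN, INT8_MAX};
}

// Reinterpret an 8-bit pattern as signed. The conversion is modular
// (guaranteed since C++20, and the behaviour of mainstream compilers before).
constexpr int asPrintedSigned(std::uint8_t bits) {
  return static_cast<std::int8_t>(bits);
}

// Widening the same 8-bit pattern to 64 bits gives different values
// depending on whether it is zero-extended or sign-extended.
constexpr std::uint64_t zext(std::uint8_t bits) { return bits; }
constexpr std::int64_t sext(std::uint8_t bits) {
  return static_cast<std::int8_t>(bits);
}

// Runtime check of the example values.
inline void checkFullRangeI8() {
  const I8Bounds b = fullRangeI8();
  assert(asPrintedSigned(b.umax) == -1); // 0xFF read as signed
  assert(zext(b.umax) == 255u);          // 0xFF zero-extended
  assert(sext(b.umax) == -1);            // 0xFF sign-extended
}
```

Because the same bit pattern widens to two different 64-bit values, keeping the bounds at the width of the SSA value avoids having to pick an extension implicitly.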

From 9daa1d57bf4d092c9befe09e88aae1109ec659f1 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 11 May 2024 12:46:32 +0200
Subject: [PATCH] [mlir][test] Extend `InferIntRangeInterface` test Ops to
 arbitrary ints

This PR prepares for some extensions to the `InferIntRangeInterface`
around the `nsw` and `nuw` flags supported in the `arith` dialect and LLVM.

We provide some common inference logic for `index` and `arith` in
`InferIntRangeCommon.h`, but our Test Ops are currently fixed to `Index`
Types. Since we test the range inference for arith Ops, especially around
the overflow behaviour, it's handy to have native support for the typical
integer types in the test Ops.

This patch
1. Changes the Attributes of `test.with_bounds` ops from `Index` to `APInt`,
which matches the internal representation in `ConstantIntRanges`.
2. Allows the use of `AnyInteger` in addition to `Index` for the operands
and results of the test Ops. This now requires explicit specification of
the type in the IR, where before `Index` was implicit.
3. Requires bounds Attrs to be specified in the precision of the SSA value,
eliminating any implicit truncation or extension. (*Could this lead to
problems?*)
---
 .../Dialect/Arith/int-range-interface.mlir    |  10 ++
 mlir/test/Dialect/Arith/int-range-opts.mlir   |  12 +-
 .../test/Dialect/GPU/int-range-interface.mlir | 106 +++++++++---------
 .../infer-int-range-test-ops.mlir             |  56 ++++-----
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     |  14 +--
 mlir/test/lib/Dialect/Test/TestOps.td         |  50 +++++----
 6 files changed, 130 insertions(+), 118 deletions(-)

diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 02a9827d19d8f..16524b3634723 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -756,3 +756,13 @@ func.func private @callee(%arg0: memref<?xindex, 4>) {
   }
   return
 }
+
+// CHECK-LABEL: func @test_i8_bounds
+// CHECK: test.reflect_bounds {smax = 127 : i8, smin = -128 : i8, umax = -1 : i8, umin = 0 : i8}
+func.func @test_i8_bounds() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
+  %1 = arith.addi %0, %cst1 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index be0a7e8ccd70b..9eb50cf6c3777 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -5,7 +5,7 @@
 //       CHECK:   return %[[C]]
 func.func @test() -> i1 {
   %cst1 = arith.constant -1 : index
-  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
   %1 = arith.cmpi eq, %0, %cst1 : index
   return %1: i1
 }
@@ -17,7 +17,7 @@ func.func @test() -> i1 {
 //       CHECK:   return %[[C]]
 func.func @test() -> i1 {
   %cst1 = arith.constant -1 : index
-  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
   %1 = arith.cmpi ne, %0, %cst1 : index
   return %1: i1
 }
@@ -30,7 +30,7 @@ func.func @test() -> i1 {
 //       CHECK:   return %[[C]]
 func.func @test() -> i1 {
   %cst = arith.constant 0 : index
-  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
   %1 = arith.cmpi sge, %0, %cst : index
   return %1: i1
 }
@@ -42,7 +42,7 @@ func.func @test() -> i1 {
 //       CHECK:   return %[[C]]
 func.func @test() -> i1 {
   %cst = arith.constant 0 : index
-  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
   %1 = arith.cmpi slt, %0, %cst : index
   return %1: i1
 }
@@ -55,7 +55,7 @@ func.func @test() -> i1 {
 //       CHECK:   return %[[C]]
 func.func @test() -> i1 {
   %cst1 = arith.constant -1 : index
-  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
   %1 = arith.cmpi sgt, %0, %cst1 : index
   return %1: i1
 }
@@ -67,7 +67,7 @@ func.func @test() -> i1 {
 //       CHECK:   return %[[C]]
 func.func @test() -> i1 {
   %cst1 = arith.constant -1 : index
-  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
   %1 = arith.cmpi sle, %0, %cst1 : index
   return %1: i1
 }
diff --git a/mlir/test/Dialect/GPU/int-range-interface.mlir b/mlir/test/Dialect/GPU/int-range-interface.mlir
index 02aec9dc0476f..980f7e5873e0c 100644
--- a/mlir/test/Dialect/GPU/int-range-interface.mlir
+++ b/mlir/test/Dialect/GPU/int-range-interface.mlir
@@ -5,46 +5,46 @@ func.func @launch_func(%arg0 : index) {
   %0 = test.with_bounds {
     umin = 3 : index, umax = 5 : index,
     smin = 3 : index, smax = 5 : index
-  }
+  } : index
   %1 = test.with_bounds {
     umin = 7 : index, umax = 11 : index,
     smin = 7 : index, smax = 11 : index
-  }
+  } : index
   gpu.launch blocks(%block_id_x, %block_id_y, %block_id_z) in (%grid_dim_x = %0, %grid_dim_y = %1, %grid_dim_z = %arg0)
       threads(%thread_id_x, %thread_id_y, %thread_id_z) in (%block_dim_x = %arg0, %block_dim_y = %0, %block_dim_z = %1) {
 
     // CHECK: test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index}
     // CHECK: test.reflect_bounds {smax = 11 : index, smin = 7 : index, umax = 11 : index, umin = 7 : index}
     // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
-    %grid_dim_x0 = test.reflect_bounds %grid_dim_x
-    %grid_dim_y0 = test.reflect_bounds %grid_dim_y
-    %grid_dim_z0 = test.reflect_bounds %grid_dim_z
+    %grid_dim_x0 = test.reflect_bounds %grid_dim_x : index
+    %grid_dim_y0 = test.reflect_bounds %grid_dim_y : index
+    %grid_dim_z0 = test.reflect_bounds %grid_dim_z : index
 
     // CHECK: test.reflect_bounds {smax = 4 : index, smin = 0 : index, umax = 4 : index, umin = 0 : index}
     // CHECK: test.reflect_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index}
     // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
-    %block_id_x0 = test.reflect_bounds %block_id_x
-    %block_id_y0 = test.reflect_bounds %block_id_y
-    %block_id_z0 = test.reflect_bounds %block_id_z
+    %block_id_x0 = test.reflect_bounds %block_id_x : index
+    %block_id_y0 = test.reflect_bounds %block_id_y : index
+    %block_id_z0 = test.reflect_bounds %block_id_z : index
 
     // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
     // CHECK: test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index}
     // CHECK: test.reflect_bounds {smax = 11 : index, smin = 7 : index, umax = 11 : index, umin = 7 : index}
-    %block_dim_x0 = test.reflect_bounds %block_dim_x
-    %block_dim_y0 = test.reflect_bounds %block_dim_y
-    %block_dim_z0 = test.reflect_bounds %block_dim_z
+    %block_dim_x0 = test.reflect_bounds %block_dim_x : index
+    %block_dim_y0 = test.reflect_bounds %block_dim_y : index
+    %block_dim_z0 = test.reflect_bounds %block_dim_z : index
 
     // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
     // CHECK: test.reflect_bounds {smax = 4 : index, smin = 0 : index, umax = 4 : index, umin = 0 : index}
     // CHECK: test.reflect_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index}
-    %thread_id_x0 = test.reflect_bounds %thread_id_x
-    %thread_id_y0 = test.reflect_bounds %thread_id_y
-    %thread_id_z0 = test.reflect_bounds %thread_id_z
+    %thread_id_x0 = test.reflect_bounds %thread_id_x : index
+    %thread_id_y0 = test.reflect_bounds %thread_id_y : index
+    %thread_id_z0 = test.reflect_bounds %thread_id_z : index
 
     // The launch bounds are not constant, and so this can't infer anything
     // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
     %thread_id_op = gpu.thread_id y
-    %thread_id_op0 = test.reflect_bounds %thread_id_op
+    %thread_id_op0 = test.reflect_bounds %thread_id_op : index
     gpu.terminator
   }
 
@@ -65,9 +65,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
-      %grid_dim_x0 = test.reflect_bounds %grid_dim_x
-      %grid_dim_y0 = test.reflect_bounds %grid_dim_y
-      %grid_dim_z0 = test.reflect_bounds %grid_dim_z
+      %grid_dim_x0 = test.reflect_bounds %grid_dim_x : index
+      %grid_dim_y0 = test.reflect_bounds %grid_dim_y : index
+      %grid_dim_z0 = test.reflect_bounds %grid_dim_z : index
 
       %block_id_x = gpu.block_id x
       %block_id_y = gpu.block_id y
@@ -76,9 +76,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
-      %block_id_x0 = test.reflect_bounds %block_id_x
-      %block_id_y0 = test.reflect_bounds %block_id_y
-      %block_id_z0 = test.reflect_bounds %block_id_z
+      %block_id_x0 = test.reflect_bounds %block_id_x : index
+      %block_id_y0 = test.reflect_bounds %block_id_y : index
+      %block_id_z0 = test.reflect_bounds %block_id_z : index
 
       %block_dim_x = gpu.block_dim x
       %block_dim_y = gpu.block_dim y
@@ -87,9 +87,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
-      %block_dim_x0 = test.reflect_bounds %block_dim_x
-      %block_dim_y0 = test.reflect_bounds %block_dim_y
-      %block_dim_z0 = test.reflect_bounds %block_dim_z
+      %block_dim_x0 = test.reflect_bounds %block_dim_x : index
+      %block_dim_y0 = test.reflect_bounds %block_dim_y : index
+      %block_dim_z0 = test.reflect_bounds %block_dim_z : index
 
       %thread_id_x = gpu.thread_id x
       %thread_id_y = gpu.thread_id y
@@ -98,9 +98,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
-      %thread_id_x0 = test.reflect_bounds %thread_id_x
-      %thread_id_y0 = test.reflect_bounds %thread_id_y
-      %thread_id_z0 = test.reflect_bounds %thread_id_z
+      %thread_id_x0 = test.reflect_bounds %thread_id_x : index
+      %thread_id_y0 = test.reflect_bounds %thread_id_y : index
+      %thread_id_z0 = test.reflect_bounds %thread_id_z : index
 
       %global_id_x = gpu.global_id x
       %global_id_y = gpu.global_id y
@@ -109,9 +109,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -8589934592 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -8589934592 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -8589934592 : index, umin = 0 : index}
-      %global_id_x0 = test.reflect_bounds %global_id_x
-      %global_id_y0 = test.reflect_bounds %global_id_y
-      %global_id_z0 = test.reflect_bounds %global_id_z
+      %global_id_x0 = test.reflect_bounds %global_id_x : index
+      %global_id_y0 = test.reflect_bounds %global_id_y : index
+      %global_id_z0 = test.reflect_bounds %global_id_z : index
 
       %subgroup_size = gpu.subgroup_size : index
       %lane_id = gpu.lane_id
@@ -122,10 +122,10 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 127 : index, smin = 0 : index, umax = 127 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
-      %subgroup_size0 = test.reflect_bounds %subgroup_size
-      %lane_id0 = test.reflect_bounds %lane_id
-      %num_subgroups0 = test.reflect_bounds %num_subgroups
-      %subgroup_id0 = test.reflect_bounds %subgroup_id
+      %subgroup_size0 = test.reflect_bounds %subgroup_size : index
+      %lane_id0 = test.reflect_bounds %lane_id : index
+      %num_subgroups0 = test.reflect_bounds %num_subgroups : index
+      %subgroup_id0 = test.reflect_bounds %subgroup_id : index
 
       llvm.return
     }
@@ -148,9 +148,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 20 : index, smin = 20 : index, umax = 20 : index, umin = 20 : index}
       // CHECK: test.reflect_bounds {smax = 24 : index, smin = 24 : index, umax = 24 : index, umin = 24 : index}
       // CHECK: test.reflect_bounds {smax = 28 : index, smin = 28 : index, umax = 28 : index, umin = 28 : index}
-      %grid_dim_x0 = test.reflect_bounds %grid_dim_x
-      %grid_dim_y0 = test.reflect_bounds %grid_dim_y
-      %grid_dim_z0 = test.reflect_bounds %grid_dim_z
+      %grid_dim_x0 = test.reflect_bounds %grid_dim_x : index
+      %grid_dim_y0 = test.reflect_bounds %grid_dim_y : index
+      %grid_dim_z0 = test.reflect_bounds %grid_dim_z : index
 
       %block_id_x = gpu.block_id x
       %block_id_y = gpu.block_id y
@@ -159,9 +159,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 19 : index, smin = 0 : index, umax = 19 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 23 : index, smin = 0 : index, umax = 23 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 27 : index, smin = 0 : index, umax = 27 : index, umin = 0 : index}
-      %block_id_x0 = test.reflect_bounds %block_id_x
-      %block_id_y0 = test.reflect_bounds %block_id_y
-      %block_id_z0 = test.reflect_bounds %block_id_z
+      %block_id_x0 = test.reflect_bounds %block_id_x : index
+      %block_id_y0 = test.reflect_bounds %block_id_y : index
+      %block_id_z0 = test.reflect_bounds %block_id_z : index
 
       %block_dim_x = gpu.block_dim x
       %block_dim_y = gpu.block_dim y
@@ -170,9 +170,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 8 : index, smin = 8 : index, umax = 8 : index, umin = 8 : index}
       // CHECK: test.reflect_bounds {smax = 12 : index, smin = 12 : index, umax = 12 : index, umin = 12 : index}
       // CHECK: test.reflect_bounds {smax = 16 : index, smin = 16 : index, umax = 16 : index, umin = 16 : index}
-      %block_dim_x0 = test.reflect_bounds %block_dim_x
-      %block_dim_y0 = test.reflect_bounds %block_dim_y
-      %block_dim_z0 = test.reflect_bounds %block_dim_z
+      %block_dim_x0 = test.reflect_bounds %block_dim_x : index
+      %block_dim_y0 = test.reflect_bounds %block_dim_y : index
+      %block_dim_z0 = test.reflect_bounds %block_dim_z : index
 
       %thread_id_x = gpu.thread_id x
       %thread_id_y = gpu.thread_id y
@@ -181,9 +181,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 11 : index, smin = 0 : index, umax = 11 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 15 : index, smin = 0 : index, umax = 15 : index, umin = 0 : index}
-      %thread_id_x0 = test.reflect_bounds %thread_id_x
-      %thread_id_y0 = test.reflect_bounds %thread_id_y
-      %thread_id_z0 = test.reflect_bounds %thread_id_z
+      %thread_id_x0 = test.reflect_bounds %thread_id_x : index
+      %thread_id_y0 = test.reflect_bounds %thread_id_y : index
+      %thread_id_z0 = test.reflect_bounds %thread_id_z : index
 
       %global_id_x = gpu.global_id x
       %global_id_y = gpu.global_id y
@@ -192,9 +192,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 159 : index, smin = 0 : index, umax = 159 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 287 : index, smin = 0 : index, umax = 287 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 447 : index, smin = 0 : index, umax = 447 : index, umin = 0 : index}
-      %global_id_x0 = test.reflect_bounds %global_id_x
-      %global_id_y0 = test.reflect_bounds %global_id_y
-      %global_id_z0 = test.reflect_bounds %global_id_z
+      %global_id_x0 = test.reflect_bounds %global_id_x : index
+      %global_id_y0 = test.reflect_bounds %global_id_y : index
+      %global_id_z0 = test.reflect_bounds %global_id_z : index
 
       %subgroup_size = gpu.subgroup_size : index
       %lane_id = gpu.lane_id
@@ -205,10 +205,10 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 127 : index, smin = 0 : index, umax = 127 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
-      %subgroup_size0 = test.reflect_bounds %subgroup_size
-      %lane_id0 = test.reflect_bounds %lane_id
-      %num_subgroups0 = test.reflect_bounds %num_subgroups
-      %subgroup_id0 = test.reflect_bounds %subgroup_id
+      %subgroup_size0 = test.reflect_bounds %subgroup_size : index
+      %lane_id0 = test.reflect_bounds %lane_id : index
+      %num_subgroups0 = test.reflect_bounds %num_subgroups : index
+      %subgroup_id0 = test.reflect_bounds %subgroup_id : index
 
       gpu.return
     }
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index c74af447d1b1f..2106eeefdca4d 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -5,7 +5,7 @@
 // CHECK: return %[[cst]]
 func.func @constant() -> index {
   %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
-                               smin = 3 : index, smax = 3 : index}
+                               smin = 3 : index, smax = 3 : index} : index
   func.return %0 : index
 }
 
@@ -13,8 +13,8 @@ func.func @constant() -> index {
 // CHECK: %[[cst:.*]] = "test.constant"() <{value = 4 : index}
 // CHECK: return %[[cst]]
 func.func @increment() -> index {
-  %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
-  %1 = test.increment %0
+  %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
+  %1 = test.increment %0 : index
   func.return %1 : index
 }
 
@@ -22,14 +22,14 @@ func.func @increment() -> index {
 // CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
 func.func @maybe_increment(%arg0 : i1) -> index {
   %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
-                               smin = 3 : index, smax = 3 : index}
+                               smin = 3 : index, smax = 3 : index} : index
   %1 = scf.if %arg0 -> index {
     scf.yield %0 : index
   } else {
-    %2 = test.increment %0
+    %2 = test.increment %0 : index
     scf.yield %2 : index
   }
-  %3 = test.reflect_bounds %1
+  %3 = test.reflect_bounds %1 : index
   func.return %3 : index
 }
 
@@ -37,15 +37,15 @@ func.func @maybe_increment(%arg0 : i1) -> index {
 // CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
 func.func @maybe_increment_br(%arg0 : i1) -> index {
   %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
-                               smin = 3 : index, smax = 3 : index}
+                               smin = 3 : index, smax = 3 : index} : index
   cf.cond_br %arg0, ^bb0, ^bb1
 ^bb0:
-    %1 = test.increment %0
+    %1 = test.increment %0 : index
     cf.br ^bb2(%1 : index)
 ^bb1:
     cf.br ^bb2(%0 : index)
 ^bb2(%2 : index):
-  %3 = test.reflect_bounds %2
+  %3 = test.reflect_bounds %2 : index
   func.return %3 : index
 }
 
@@ -53,16 +53,16 @@ func.func @maybe_increment_br(%arg0 : i1) -> index {
 // CHECK: test.reflect_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index}
 func.func @for_bounds() -> index {
   %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
-                                smin = 0 : index, smax = 0 : index}
+                                smin = 0 : index, smax = 0 : index} : index
   %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
-                                smin = 1 : index, smax = 1 : index}
+                                smin = 1 : index, smax = 1 : index} : index
   %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
-                                smin = 2 : index, smax = 2 : index}
+                                smin = 2 : index, smax = 2 : index} : index
 
   %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
     scf.yield %arg0 : index
   }
-  %1 = test.reflect_bounds %0
+  %1 = test.reflect_bounds %0 : index
   func.return %1 : index
 }
 
@@ -70,17 +70,17 @@ func.func @for_bounds() -> index {
 // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
 func.func @no_analysis_of_loop_variants() -> index {
   %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
-                                smin = 0 : index, smax = 0 : index}
+                                smin = 0 : index, smax = 0 : index} : index
   %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
-                                smin = 1 : index, smax = 1 : index}
+                                smin = 1 : index, smax = 1 : index} : index
   %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
-                                smin = 2 : index, smax = 2 : index}
+                                smin = 2 : index, smax = 2 : index} : index
 
   %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
-    %1 = test.increment %arg2
+    %1 = test.increment %arg2 : index
     scf.yield %1 : index
   }
-  %2 = test.reflect_bounds %0
+  %2 = test.reflect_bounds %0 : index
   func.return %2 : index
 }
 
@@ -88,8 +88,8 @@ func.func @no_analysis_of_loop_variants() -> index {
 // CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
 func.func @region_args() {
   test.with_bounds_region { umin = 3 : index, umax = 4 : index,
-                            smin = 3 : index, smax = 4 : index } %arg0 {
-    %0 = test.reflect_bounds %arg0
+                            smin = 3 : index, smax = 4 : index } %arg0 : index {
+    %0 = test.reflect_bounds %arg0 : index
   }
   func.return
 }
@@ -97,7 +97,7 @@ func.func @region_args() {
 // CHECK-LABEL: func @func_args_unbound
 // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
 func.func @func_args_unbound(%arg0 : index) -> index {
-  %0 = test.reflect_bounds %arg0
+  %0 = test.reflect_bounds %arg0 : index
   func.return %0 : index
 }
 
@@ -106,7 +106,7 @@ func.func @propagate_across_while_loop_false() -> index {
   // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
   // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
-                          smin = 0 : index, smax = 0 : index }
+                          smin = 0 : index, smax = 0 : index } : index
   %1 = scf.while : () -> index {
     %false = arith.constant false
     // CHECK: scf.condition(%{{.*}}) %[[C0]]
@@ -116,7 +116,7 @@ func.func @propagate_across_while_loop_false() -> index {
     scf.yield
   }
   // CHECK: return %[[C1]]
-  %2 = test.increment %1
+  %2 = test.increment %1 : index
   return %2 : index
 }
 
@@ -125,7 +125,7 @@ func.func @propagate_across_while_loop(%arg0 : i1) -> index {
   // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
   // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
-                          smin = 0 : index, smax = 0 : index }
+                          smin = 0 : index, smax = 0 : index } : index
   %1 = scf.while : () -> index {
     // CHECK: scf.condition(%{{.*}}) %[[C0]]
     scf.condition(%arg0) %0 : index
@@ -134,7 +134,7 @@ func.func @propagate_across_while_loop(%arg0 : i1) -> index {
     scf.yield
   }
   // CHECK: return %[[C1]]
-  %2 = test.increment %1
+  %2 = test.increment %1 : index
   return %2 : index
 }
 
@@ -142,7 +142,7 @@ func.func @propagate_across_while_loop(%arg0 : i1) -> index {
 func.func @dont_propagate_across_infinite_loop() -> index {
   // CHECK: %[[C0:.*]] = "test.constant"() <{value = 0
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
-                          smin = 0 : index, smax = 0 : index }
+                          smin = 0 : index, smax = 0 : index } : index
   // CHECK: %[[loopRes:.*]] = scf.while
   %1 = scf.while : () -> index {
     %true = arith.constant true
@@ -152,8 +152,8 @@ func.func @dont_propagate_across_infinite_loop() -> index {
   ^bb0(%i1: index):
     scf.yield
   }
-  // CHECK: %[[ret:.*]] = test.reflect_bounds %[[loopRes]]
-  %2 = test.reflect_bounds %1
+  // CHECK: %[[ret:.*]] = test.reflect_bounds %[[loopRes]] : index
+  %2 = test.reflect_bounds %1 : index
   // CHECK: return %[[ret]]
   return %2 : index
 }
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 08df2e5e12286..04acf347243d7 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -662,8 +662,7 @@ ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
 
   // Parse the input argument
   OpAsmParser::Argument argInfo;
-  argInfo.type = parser.getBuilder().getIndexType();
-  if (failed(parser.parseArgument(argInfo)))
+  if (failed(parser.parseArgument(argInfo, true)))
     return failure();
 
   // Parse the body region, and reuse the operand info as the argument info.
@@ -675,7 +674,7 @@ void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
   p << ' ';
   p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
-                        /*omitType=*/true);
+                        /*omitType=*/false);
   p << ' ';
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
 }
@@ -706,10 +705,11 @@ void TestReflectBoundsOp::inferResultRanges(
   const ConstantIntRanges &range = argRanges[0];
   MLIRContext *ctx = getContext();
   Builder b(ctx);
-  setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
-  setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
-  setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
-  setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
+  auto intTy = getType();
+  setUminAttr(b.getIntegerAttr(intTy, range.umin()));
+  setUmaxAttr(b.getIntegerAttr(intTy, range.umax()));
+  setSminAttr(b.getIntegerAttr(intTy, range.smin()));
+  setSmaxAttr(b.getIntegerAttr(intTy, range.smax()));
   setResultRanges(getResult(), range);
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5352d574ac394..9ed9f910f4e2c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2733,49 +2733,51 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
 //===----------------------------------------------------------------------===//
 // Test InferIntRangeInterface
 //===----------------------------------------------------------------------===//
+def InferIntRangeType : AnyTypeOf<[AnyInteger, Index]>;
+
 def TestWithBoundsOp : TEST_Op<"with_bounds",
                           [DeclareOpInterfaceMethods<InferIntRangeInterface>,
                            NoMemoryEffect]> {
-  let arguments = (ins IndexAttr:$umin,
-                       IndexAttr:$umax,
-                       IndexAttr:$smin,
-                       IndexAttr:$smax);
-  let results = (outs Index:$fakeVal);
+  let arguments = (ins APIntAttr:$umin,
+                       APIntAttr:$umax,
+                       APIntAttr:$smin,
+                       APIntAttr:$smax);
+  let results = (outs InferIntRangeType:$fakeVal);
 
-  let assemblyFormat = "attr-dict";
+  let assemblyFormat = "attr-dict `:` type($fakeVal)";
 }
 
 def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
                           [DeclareOpInterfaceMethods<InferIntRangeInterface>,
                            SingleBlock, NoTerminator]> {
-  let arguments = (ins IndexAttr:$umin,
-                       IndexAttr:$umax,
-                       IndexAttr:$smin,
-                       IndexAttr:$smax);
-  // The region has one argument of index type
+  let arguments = (ins APIntAttr:$umin,
+                       APIntAttr:$umax,
+                       APIntAttr:$smin,
+                       APIntAttr:$smax);
+  // The region has one argument of any integer type
   let regions = (region SizedRegion<1>:$region);
   let hasCustomAssemblyFormat = 1;
 }
 
 def TestIncrementOp : TEST_Op<"increment",
                          [DeclareOpInterfaceMethods<InferIntRangeInterface>,
-                         NoMemoryEffect]> {
-  let arguments = (ins Index:$value);
-  let results = (outs Index:$result);
+                         NoMemoryEffect, AllTypesMatch<["value", "result"]>]> {
+  let arguments = (ins InferIntRangeType:$value);
+  let results = (outs InferIntRangeType:$result);
 
-  let assemblyFormat = "attr-dict $value";
+  let assemblyFormat = "attr-dict $value `:` type($result)";
 }
 
 def TestReflectBoundsOp : TEST_Op<"reflect_bounds",
-                         [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
-  let arguments = (ins Index:$value,
-                       OptionalAttr<IndexAttr>:$umin,
-                       OptionalAttr<IndexAttr>:$umax,
-                       OptionalAttr<IndexAttr>:$smin,
-                       OptionalAttr<IndexAttr>:$smax);
-  let results = (outs Index:$result);
-
-  let assemblyFormat = "attr-dict $value";
+                         [DeclareOpInterfaceMethods<InferIntRangeInterface>, AllTypesMatch<["value", "result"]>]> {
+  let arguments = (ins InferIntRangeType:$value,
+                       OptionalAttr<APIntAttr>:$umin,
+                       OptionalAttr<APIntAttr>:$umax,
+                       OptionalAttr<APIntAttr>:$smin,
+                       OptionalAttr<APIntAttr>:$smax);
+  let results = (outs InferIntRangeType:$result);
+
+  let assemblyFormat = "attr-dict $value `:` type($result)";
 }
 
 //===----------------------------------------------------------------------===//


