[Mlir-commits] [mlir] acd1007 - [mlir][test] Extend `InferIntRangeInterface` test Ops to arbitrary ints (#91850)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 14 11:33:21 PDT 2024


Author: Felix Schneider
Date: 2024-05-14T20:33:16+02:00
New Revision: acd100747fff85e7cfb67caa6c0f1053e820c1ac

URL: https://github.com/llvm/llvm-project/commit/acd100747fff85e7cfb67caa6c0f1053e820c1ac
DIFF: https://github.com/llvm/llvm-project/commit/acd100747fff85e7cfb67caa6c0f1053e820c1ac.diff

LOG: [mlir][test] Extend `InferIntRangeInterface` test Ops to arbitrary ints (#91850)

This PR is in preparation for some extensions to the
`InferIntRangeInterface` around the `nsw` and `nuw` flags supported in
the `arith` dialect and LLVM.

We provide some common inference logic for `index` and `arith` in
`InferIntRangeCommon.h`, but our test ops are currently fixed to the
`index` type. As we test the range inference for `arith` ops, especially
around their overflow behaviour, it's handy to have native support for the
typical integer types in the test ops.

This patch
1. Changes the bound attributes of the `test.with_bounds`,
`test.with_bounds_region` and `test.reflect_bounds` ops from `IndexAttr`
to `APIntAttr`, which matches the internal representation in
`ConstantIntRanges`.
2. Allows the use of `AnyInteger` in addition to `Index` for the
operands and results of the test ops. The type must now be specified
explicitly in the IR (e.g. `test.reflect_bounds %x : i8`), whereas before
`index` was implicit.
3. Requires the bound attributes to be specified at the bit width of the
SSA value, eliminating any implicit truncation or extension. (*Could this
lead to problems?*)

Added: 
    

Modified: 
    mlir/test/Dialect/Arith/int-range-interface.mlir
    mlir/test/Dialect/Arith/int-range-opts.mlir
    mlir/test/Dialect/GPU/int-range-interface.mlir
    mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
    mlir/test/lib/Dialect/Test/TestOpDefs.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 02a9827d19d8f..16524b3634723 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -756,3 +756,13 @@ func.func private @callee(%arg0: memref<?xindex, 4>) {
   }
   return
 }
+
+// CHECK-LABEL: func @test_i8_bounds
+// CHECK: test.reflect_bounds {smax = 127 : i8, smin = -128 : i8, umax = -1 : i8, umin = 0 : i8}
+func.func @test_i8_bounds() -> i8 {
+  %cst1 = arith.constant 1 : i8
+  %0 = test.with_bounds { umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8 } : i8
+  %1 = arith.addi %0, %cst1 : i8
+  %2 = test.reflect_bounds %1 : i8
+  return %2: i8
+}

diff  --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index 4c3c0854ed026..6179003ab4e74 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -5,7 +5,7 @@
 //       CHECK:   return %[[C]]
 func.func @test() -> i1 {
   %cst1 = arith.constant -1 : index
-  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
   %1 = arith.cmpi eq, %0, %cst1 : index
   return %1: i1
 }
@@ -17,7 +17,7 @@ func.func @test() -> i1 {
 //       CHECK:   return %[[C]]
 func.func @test() -> i1 {
   %cst1 = arith.constant -1 : index
-  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
   %1 = arith.cmpi ne, %0, %cst1 : index
   return %1: i1
 }
@@ -30,7 +30,7 @@ func.func @test() -> i1 {
 //       CHECK:   return %[[C]]
 func.func @test() -> i1 {
   %cst = arith.constant 0 : index
-  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
   %1 = arith.cmpi sge, %0, %cst : index
   return %1: i1
 }
@@ -42,7 +42,7 @@ func.func @test() -> i1 {
 //       CHECK:   return %[[C]]
 func.func @test() -> i1 {
   %cst = arith.constant 0 : index
-  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
   %1 = arith.cmpi slt, %0, %cst : index
   return %1: i1
 }
@@ -55,7 +55,7 @@ func.func @test() -> i1 {
 //       CHECK:   return %[[C]]
 func.func @test() -> i1 {
   %cst1 = arith.constant -1 : index
-  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
   %1 = arith.cmpi sgt, %0, %cst1 : index
   return %1: i1
 }
@@ -67,7 +67,7 @@ func.func @test() -> i1 {
 //       CHECK:   return %[[C]]
 func.func @test() -> i1 {
   %cst1 = arith.constant -1 : index
-  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
+  %0 = test.with_bounds { umin = 0 : index, umax = 0x7fffffffffffffff : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
   %1 = arith.cmpi sle, %0, %cst1 : index
   return %1: i1
 }
@@ -75,28 +75,24 @@ func.func @test() -> i1 {
 // -----
 
 // CHECK-LABEL: func @test
-// CHECK: test.reflect_bounds {smax = 24 : index, smin = 0 : index, umax = 24 : index, umin = 0 : index}
-func.func @test() -> index {
+// CHECK: test.reflect_bounds {smax = 24 : i8, smin = 0 : i8, umax = 24 : i8, umin = 0 : i8}
+func.func @test() -> i8 {
   %cst1 = arith.constant 1 : i8
-  %0 = test.with_bounds { umin = 0 : index, umax = 12 : index, smin = 0 : index, smax = 12 : index }
-  %i8val = arith.index_cast %0 : index to i8
+  %i8val = test.with_bounds { umin = 0 : i8, umax = 12 : i8, smin = 0 : i8, smax = 12 : i8 } : i8
   %shifted = arith.shli %i8val, %cst1 : i8
-  %si = arith.index_cast %shifted : i8 to index
-  %1 = test.reflect_bounds %si
-  return %1: index
+  %1 = test.reflect_bounds %shifted : i8
+  return %1: i8
 }
 
 // -----
 
 // CHECK-LABEL: func @test
-// CHECK: test.reflect_bounds {smax = 127 : index, smin = -128 : index, umax = -1 : index, umin = 0 : index}
-func.func @test() -> index {
+// CHECK: test.reflect_bounds {smax = 127 : i8, smin = -128 : i8, umax = -1 : i8, umin = 0 : i8}
+func.func @test() -> i8 {
   %cst1 = arith.constant 1 : i8
-  %0 = test.with_bounds { umin = 0 : index, umax = 127 : index, smin = 0 : index, smax = 127 : index }
-  %i8val = arith.index_cast %0 : index to i8
+  %i8val = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
   %shifted = arith.shli %i8val, %cst1 : i8
-  %si = arith.index_cast %shifted : i8 to index
-  %1 = test.reflect_bounds %si
-  return %1: index
+  %1 = test.reflect_bounds %shifted : i8
+  return %1: i8
 }
 

diff  --git a/mlir/test/Dialect/GPU/int-range-interface.mlir b/mlir/test/Dialect/GPU/int-range-interface.mlir
index 02aec9dc0476f..980f7e5873e0c 100644
--- a/mlir/test/Dialect/GPU/int-range-interface.mlir
+++ b/mlir/test/Dialect/GPU/int-range-interface.mlir
@@ -5,46 +5,46 @@ func.func @launch_func(%arg0 : index) {
   %0 = test.with_bounds {
     umin = 3 : index, umax = 5 : index,
     smin = 3 : index, smax = 5 : index
-  }
+  } : index
   %1 = test.with_bounds {
     umin = 7 : index, umax = 11 : index,
     smin = 7 : index, smax = 11 : index
-  }
+  } : index
   gpu.launch blocks(%block_id_x, %block_id_y, %block_id_z) in (%grid_dim_x = %0, %grid_dim_y = %1, %grid_dim_z = %arg0)
       threads(%thread_id_x, %thread_id_y, %thread_id_z) in (%block_dim_x = %arg0, %block_dim_y = %0, %block_dim_z = %1) {
 
     // CHECK: test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index}
     // CHECK: test.reflect_bounds {smax = 11 : index, smin = 7 : index, umax = 11 : index, umin = 7 : index}
     // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
-    %grid_dim_x0 = test.reflect_bounds %grid_dim_x
-    %grid_dim_y0 = test.reflect_bounds %grid_dim_y
-    %grid_dim_z0 = test.reflect_bounds %grid_dim_z
+    %grid_dim_x0 = test.reflect_bounds %grid_dim_x : index
+    %grid_dim_y0 = test.reflect_bounds %grid_dim_y : index
+    %grid_dim_z0 = test.reflect_bounds %grid_dim_z : index
 
     // CHECK: test.reflect_bounds {smax = 4 : index, smin = 0 : index, umax = 4 : index, umin = 0 : index}
     // CHECK: test.reflect_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index}
     // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
-    %block_id_x0 = test.reflect_bounds %block_id_x
-    %block_id_y0 = test.reflect_bounds %block_id_y
-    %block_id_z0 = test.reflect_bounds %block_id_z
+    %block_id_x0 = test.reflect_bounds %block_id_x : index
+    %block_id_y0 = test.reflect_bounds %block_id_y : index
+    %block_id_z0 = test.reflect_bounds %block_id_z : index
 
     // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
     // CHECK: test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index}
     // CHECK: test.reflect_bounds {smax = 11 : index, smin = 7 : index, umax = 11 : index, umin = 7 : index}
-    %block_dim_x0 = test.reflect_bounds %block_dim_x
-    %block_dim_y0 = test.reflect_bounds %block_dim_y
-    %block_dim_z0 = test.reflect_bounds %block_dim_z
+    %block_dim_x0 = test.reflect_bounds %block_dim_x : index
+    %block_dim_y0 = test.reflect_bounds %block_dim_y : index
+    %block_dim_z0 = test.reflect_bounds %block_dim_z : index
 
     // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
     // CHECK: test.reflect_bounds {smax = 4 : index, smin = 0 : index, umax = 4 : index, umin = 0 : index}
     // CHECK: test.reflect_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index}
-    %thread_id_x0 = test.reflect_bounds %thread_id_x
-    %thread_id_y0 = test.reflect_bounds %thread_id_y
-    %thread_id_z0 = test.reflect_bounds %thread_id_z
+    %thread_id_x0 = test.reflect_bounds %thread_id_x : index
+    %thread_id_y0 = test.reflect_bounds %thread_id_y : index
+    %thread_id_z0 = test.reflect_bounds %thread_id_z : index
 
     // The launch bounds are not constant, and so this can't infer anything
     // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
     %thread_id_op = gpu.thread_id y
-    %thread_id_op0 = test.reflect_bounds %thread_id_op
+    %thread_id_op0 = test.reflect_bounds %thread_id_op : index
     gpu.terminator
   }
 
@@ -65,9 +65,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
-      %grid_dim_x0 = test.reflect_bounds %grid_dim_x
-      %grid_dim_y0 = test.reflect_bounds %grid_dim_y
-      %grid_dim_z0 = test.reflect_bounds %grid_dim_z
+      %grid_dim_x0 = test.reflect_bounds %grid_dim_x : index
+      %grid_dim_y0 = test.reflect_bounds %grid_dim_y : index
+      %grid_dim_z0 = test.reflect_bounds %grid_dim_z : index
 
       %block_id_x = gpu.block_id x
       %block_id_y = gpu.block_id y
@@ -76,9 +76,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
-      %block_id_x0 = test.reflect_bounds %block_id_x
-      %block_id_y0 = test.reflect_bounds %block_id_y
-      %block_id_z0 = test.reflect_bounds %block_id_z
+      %block_id_x0 = test.reflect_bounds %block_id_x : index
+      %block_id_y0 = test.reflect_bounds %block_id_y : index
+      %block_id_z0 = test.reflect_bounds %block_id_z : index
 
       %block_dim_x = gpu.block_dim x
       %block_dim_y = gpu.block_dim y
@@ -87,9 +87,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
-      %block_dim_x0 = test.reflect_bounds %block_dim_x
-      %block_dim_y0 = test.reflect_bounds %block_dim_y
-      %block_dim_z0 = test.reflect_bounds %block_dim_z
+      %block_dim_x0 = test.reflect_bounds %block_dim_x : index
+      %block_dim_y0 = test.reflect_bounds %block_dim_y : index
+      %block_dim_z0 = test.reflect_bounds %block_dim_z : index
 
       %thread_id_x = gpu.thread_id x
       %thread_id_y = gpu.thread_id y
@@ -98,9 +98,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
-      %thread_id_x0 = test.reflect_bounds %thread_id_x
-      %thread_id_y0 = test.reflect_bounds %thread_id_y
-      %thread_id_z0 = test.reflect_bounds %thread_id_z
+      %thread_id_x0 = test.reflect_bounds %thread_id_x : index
+      %thread_id_y0 = test.reflect_bounds %thread_id_y : index
+      %thread_id_z0 = test.reflect_bounds %thread_id_z : index
 
       %global_id_x = gpu.global_id x
       %global_id_y = gpu.global_id y
@@ -109,9 +109,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -8589934592 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -8589934592 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -8589934592 : index, umin = 0 : index}
-      %global_id_x0 = test.reflect_bounds %global_id_x
-      %global_id_y0 = test.reflect_bounds %global_id_y
-      %global_id_z0 = test.reflect_bounds %global_id_z
+      %global_id_x0 = test.reflect_bounds %global_id_x : index
+      %global_id_y0 = test.reflect_bounds %global_id_y : index
+      %global_id_z0 = test.reflect_bounds %global_id_z : index
 
       %subgroup_size = gpu.subgroup_size : index
       %lane_id = gpu.lane_id
@@ -122,10 +122,10 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 127 : index, smin = 0 : index, umax = 127 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
-      %subgroup_size0 = test.reflect_bounds %subgroup_size
-      %lane_id0 = test.reflect_bounds %lane_id
-      %num_subgroups0 = test.reflect_bounds %num_subgroups
-      %subgroup_id0 = test.reflect_bounds %subgroup_id
+      %subgroup_size0 = test.reflect_bounds %subgroup_size : index
+      %lane_id0 = test.reflect_bounds %lane_id : index
+      %num_subgroups0 = test.reflect_bounds %num_subgroups : index
+      %subgroup_id0 = test.reflect_bounds %subgroup_id : index
 
       llvm.return
     }
@@ -148,9 +148,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 20 : index, smin = 20 : index, umax = 20 : index, umin = 20 : index}
       // CHECK: test.reflect_bounds {smax = 24 : index, smin = 24 : index, umax = 24 : index, umin = 24 : index}
       // CHECK: test.reflect_bounds {smax = 28 : index, smin = 28 : index, umax = 28 : index, umin = 28 : index}
-      %grid_dim_x0 = test.reflect_bounds %grid_dim_x
-      %grid_dim_y0 = test.reflect_bounds %grid_dim_y
-      %grid_dim_z0 = test.reflect_bounds %grid_dim_z
+      %grid_dim_x0 = test.reflect_bounds %grid_dim_x : index
+      %grid_dim_y0 = test.reflect_bounds %grid_dim_y : index
+      %grid_dim_z0 = test.reflect_bounds %grid_dim_z : index
 
       %block_id_x = gpu.block_id x
       %block_id_y = gpu.block_id y
@@ -159,9 +159,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 19 : index, smin = 0 : index, umax = 19 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 23 : index, smin = 0 : index, umax = 23 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 27 : index, smin = 0 : index, umax = 27 : index, umin = 0 : index}
-      %block_id_x0 = test.reflect_bounds %block_id_x
-      %block_id_y0 = test.reflect_bounds %block_id_y
-      %block_id_z0 = test.reflect_bounds %block_id_z
+      %block_id_x0 = test.reflect_bounds %block_id_x : index
+      %block_id_y0 = test.reflect_bounds %block_id_y : index
+      %block_id_z0 = test.reflect_bounds %block_id_z : index
 
       %block_dim_x = gpu.block_dim x
       %block_dim_y = gpu.block_dim y
@@ -170,9 +170,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 8 : index, smin = 8 : index, umax = 8 : index, umin = 8 : index}
       // CHECK: test.reflect_bounds {smax = 12 : index, smin = 12 : index, umax = 12 : index, umin = 12 : index}
       // CHECK: test.reflect_bounds {smax = 16 : index, smin = 16 : index, umax = 16 : index, umin = 16 : index}
-      %block_dim_x0 = test.reflect_bounds %block_dim_x
-      %block_dim_y0 = test.reflect_bounds %block_dim_y
-      %block_dim_z0 = test.reflect_bounds %block_dim_z
+      %block_dim_x0 = test.reflect_bounds %block_dim_x : index
+      %block_dim_y0 = test.reflect_bounds %block_dim_y : index
+      %block_dim_z0 = test.reflect_bounds %block_dim_z : index
 
       %thread_id_x = gpu.thread_id x
       %thread_id_y = gpu.thread_id y
@@ -181,9 +181,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 11 : index, smin = 0 : index, umax = 11 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 15 : index, smin = 0 : index, umax = 15 : index, umin = 0 : index}
-      %thread_id_x0 = test.reflect_bounds %thread_id_x
-      %thread_id_y0 = test.reflect_bounds %thread_id_y
-      %thread_id_z0 = test.reflect_bounds %thread_id_z
+      %thread_id_x0 = test.reflect_bounds %thread_id_x : index
+      %thread_id_y0 = test.reflect_bounds %thread_id_y : index
+      %thread_id_z0 = test.reflect_bounds %thread_id_z : index
 
       %global_id_x = gpu.global_id x
       %global_id_y = gpu.global_id y
@@ -192,9 +192,9 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 159 : index, smin = 0 : index, umax = 159 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 287 : index, smin = 0 : index, umax = 287 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 447 : index, smin = 0 : index, umax = 447 : index, umin = 0 : index}
-      %global_id_x0 = test.reflect_bounds %global_id_x
-      %global_id_y0 = test.reflect_bounds %global_id_y
-      %global_id_z0 = test.reflect_bounds %global_id_z
+      %global_id_x0 = test.reflect_bounds %global_id_x : index
+      %global_id_y0 = test.reflect_bounds %global_id_y : index
+      %global_id_z0 = test.reflect_bounds %global_id_z : index
 
       %subgroup_size = gpu.subgroup_size : index
       %lane_id = gpu.lane_id
@@ -205,10 +205,10 @@ module attributes {gpu.container_module} {
       // CHECK: test.reflect_bounds {smax = 127 : index, smin = 0 : index, umax = 127 : index, umin = 0 : index}
       // CHECK: test.reflect_bounds {smax = 4294967295 : index, smin = 1 : index, umax = 4294967295 : index, umin = 1 : index}
       // CHECK: test.reflect_bounds {smax = 4294967294 : index, smin = 0 : index, umax = 4294967294 : index, umin = 0 : index}
-      %subgroup_size0 = test.reflect_bounds %subgroup_size
-      %lane_id0 = test.reflect_bounds %lane_id
-      %num_subgroups0 = test.reflect_bounds %num_subgroups
-      %subgroup_id0 = test.reflect_bounds %subgroup_id
+      %subgroup_size0 = test.reflect_bounds %subgroup_size : index
+      %lane_id0 = test.reflect_bounds %lane_id : index
+      %num_subgroups0 = test.reflect_bounds %num_subgroups : index
+      %subgroup_id0 = test.reflect_bounds %subgroup_id : index
 
       gpu.return
     }

diff  --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index c74af447d1b1f..2106eeefdca4d 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -5,7 +5,7 @@
 // CHECK: return %[[cst]]
 func.func @constant() -> index {
   %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
-                               smin = 3 : index, smax = 3 : index}
+                               smin = 3 : index, smax = 3 : index} : index
   func.return %0 : index
 }
 
@@ -13,8 +13,8 @@ func.func @constant() -> index {
 // CHECK: %[[cst:.*]] = "test.constant"() <{value = 4 : index}
 // CHECK: return %[[cst]]
 func.func @increment() -> index {
-  %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
-  %1 = test.increment %0
+  %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } : index
+  %1 = test.increment %0 : index
   func.return %1 : index
 }
 
@@ -22,14 +22,14 @@ func.func @increment() -> index {
 // CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
 func.func @maybe_increment(%arg0 : i1) -> index {
   %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
-                               smin = 3 : index, smax = 3 : index}
+                               smin = 3 : index, smax = 3 : index} : index
   %1 = scf.if %arg0 -> index {
     scf.yield %0 : index
   } else {
-    %2 = test.increment %0
+    %2 = test.increment %0 : index
     scf.yield %2 : index
   }
-  %3 = test.reflect_bounds %1
+  %3 = test.reflect_bounds %1 : index
   func.return %3 : index
 }
 
@@ -37,15 +37,15 @@ func.func @maybe_increment(%arg0 : i1) -> index {
 // CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
 func.func @maybe_increment_br(%arg0 : i1) -> index {
   %0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
-                               smin = 3 : index, smax = 3 : index}
+                               smin = 3 : index, smax = 3 : index} : index
   cf.cond_br %arg0, ^bb0, ^bb1
 ^bb0:
-    %1 = test.increment %0
+    %1 = test.increment %0 : index
     cf.br ^bb2(%1 : index)
 ^bb1:
     cf.br ^bb2(%0 : index)
 ^bb2(%2 : index):
-  %3 = test.reflect_bounds %2
+  %3 = test.reflect_bounds %2 : index
   func.return %3 : index
 }
 
@@ -53,16 +53,16 @@ func.func @maybe_increment_br(%arg0 : i1) -> index {
 // CHECK: test.reflect_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index}
 func.func @for_bounds() -> index {
   %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
-                                smin = 0 : index, smax = 0 : index}
+                                smin = 0 : index, smax = 0 : index} : index
   %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
-                                smin = 1 : index, smax = 1 : index}
+                                smin = 1 : index, smax = 1 : index} : index
   %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
-                                smin = 2 : index, smax = 2 : index}
+                                smin = 2 : index, smax = 2 : index} : index
 
   %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
     scf.yield %arg0 : index
   }
-  %1 = test.reflect_bounds %0
+  %1 = test.reflect_bounds %0 : index
   func.return %1 : index
 }
 
@@ -70,17 +70,17 @@ func.func @for_bounds() -> index {
 // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
 func.func @no_analysis_of_loop_variants() -> index {
   %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
-                                smin = 0 : index, smax = 0 : index}
+                                smin = 0 : index, smax = 0 : index} : index
   %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
-                                smin = 1 : index, smax = 1 : index}
+                                smin = 1 : index, smax = 1 : index} : index
   %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
-                                smin = 2 : index, smax = 2 : index}
+                                smin = 2 : index, smax = 2 : index} : index
 
   %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
-    %1 = test.increment %arg2
+    %1 = test.increment %arg2 : index
     scf.yield %1 : index
   }
-  %2 = test.reflect_bounds %0
+  %2 = test.reflect_bounds %0 : index
   func.return %2 : index
 }
 
@@ -88,8 +88,8 @@ func.func @no_analysis_of_loop_variants() -> index {
 // CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
 func.func @region_args() {
   test.with_bounds_region { umin = 3 : index, umax = 4 : index,
-                            smin = 3 : index, smax = 4 : index } %arg0 {
-    %0 = test.reflect_bounds %arg0
+                            smin = 3 : index, smax = 4 : index } %arg0 : index {
+    %0 = test.reflect_bounds %arg0 : index
   }
   func.return
 }
@@ -97,7 +97,7 @@ func.func @region_args() {
 // CHECK-LABEL: func @func_args_unbound
 // CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
 func.func @func_args_unbound(%arg0 : index) -> index {
-  %0 = test.reflect_bounds %arg0
+  %0 = test.reflect_bounds %arg0 : index
   func.return %0 : index
 }
 
@@ -106,7 +106,7 @@ func.func @propagate_across_while_loop_false() -> index {
   // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
   // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
-                          smin = 0 : index, smax = 0 : index }
+                          smin = 0 : index, smax = 0 : index } : index
   %1 = scf.while : () -> index {
     %false = arith.constant false
     // CHECK: scf.condition(%{{.*}}) %[[C0]]
@@ -116,7 +116,7 @@ func.func @propagate_across_while_loop_false() -> index {
     scf.yield
   }
   // CHECK: return %[[C1]]
-  %2 = test.increment %1
+  %2 = test.increment %1 : index
   return %2 : index
 }
 
@@ -125,7 +125,7 @@ func.func @propagate_across_while_loop(%arg0 : i1) -> index {
   // CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
   // CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
-                          smin = 0 : index, smax = 0 : index }
+                          smin = 0 : index, smax = 0 : index } : index
   %1 = scf.while : () -> index {
     // CHECK: scf.condition(%{{.*}}) %[[C0]]
     scf.condition(%arg0) %0 : index
@@ -134,7 +134,7 @@ func.func @propagate_across_while_loop(%arg0 : i1) -> index {
     scf.yield
   }
   // CHECK: return %[[C1]]
-  %2 = test.increment %1
+  %2 = test.increment %1 : index
   return %2 : index
 }
 
@@ -142,7 +142,7 @@ func.func @propagate_across_while_loop(%arg0 : i1) -> index {
 func.func @dont_propagate_across_infinite_loop() -> index {
   // CHECK: %[[C0:.*]] = "test.constant"() <{value = 0
   %0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
-                          smin = 0 : index, smax = 0 : index }
+                          smin = 0 : index, smax = 0 : index } : index
   // CHECK: %[[loopRes:.*]] = scf.while
   %1 = scf.while : () -> index {
     %true = arith.constant true
@@ -152,8 +152,8 @@ func.func @dont_propagate_across_infinite_loop() -> index {
   ^bb0(%i1: index):
     scf.yield
   }
-  // CHECK: %[[ret:.*]] = test.reflect_bounds %[[loopRes]]
-  %2 = test.reflect_bounds %1
+  // CHECK: %[[ret:.*]] = test.reflect_bounds %[[loopRes]] : index
+  %2 = test.reflect_bounds %1 : index
   // CHECK: return %[[ret]]
   return %2 : index
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 0b676db18af41..bfee0391f6708 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -663,8 +663,7 @@ ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
 
   // Parse the input argument
   OpAsmParser::Argument argInfo;
-  argInfo.type = parser.getBuilder().getIndexType();
-  if (failed(parser.parseArgument(argInfo)))
+  if (failed(parser.parseArgument(argInfo, true)))
     return failure();
 
   // Parse the body region, and reuse the operand info as the argument info.
@@ -676,7 +675,7 @@ void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
   p << ' ';
   p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
-                        /*omitType=*/true);
+                        /*omitType=*/false);
   p << ' ';
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
 }
@@ -707,10 +706,11 @@ void TestReflectBoundsOp::inferResultRanges(
   const ConstantIntRanges &range = argRanges[0];
   MLIRContext *ctx = getContext();
   Builder b(ctx);
-  setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
-  setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
-  setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
-  setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
+  auto intTy = getType();
+  setUminAttr(b.getIntegerAttr(intTy, range.umin()));
+  setUmaxAttr(b.getIntegerAttr(intTy, range.umax()));
+  setSminAttr(b.getIntegerAttr(intTy, range.smin()));
+  setSmaxAttr(b.getIntegerAttr(intTy, range.smax()));
   setResultRanges(getResult(), range);
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 7fc3d22d18958..befe6aa6cede4 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2734,49 +2734,51 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
 //===----------------------------------------------------------------------===//
 // Test InferIntRangeInterface
 //===----------------------------------------------------------------------===//
+def InferIntRangeType : AnyTypeOf<[AnyInteger, Index]>;
+
 def TestWithBoundsOp : TEST_Op<"with_bounds",
                           [DeclareOpInterfaceMethods<InferIntRangeInterface>,
                            NoMemoryEffect]> {
-  let arguments = (ins IndexAttr:$umin,
-                       IndexAttr:$umax,
-                       IndexAttr:$smin,
-                       IndexAttr:$smax);
-  let results = (outs Index:$fakeVal);
+  let arguments = (ins APIntAttr:$umin,
+                       APIntAttr:$umax,
+                       APIntAttr:$smin,
+                       APIntAttr:$smax);
+  let results = (outs InferIntRangeType:$fakeVal);
 
-  let assemblyFormat = "attr-dict";
+  let assemblyFormat = "attr-dict `:` type($fakeVal)";
 }
 
 def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
                           [DeclareOpInterfaceMethods<InferIntRangeInterface>,
                            SingleBlock, NoTerminator]> {
-  let arguments = (ins IndexAttr:$umin,
-                       IndexAttr:$umax,
-                       IndexAttr:$smin,
-                       IndexAttr:$smax);
-  // The region has one argument of index type
+  let arguments = (ins APIntAttr:$umin,
+                       APIntAttr:$umax,
+                       APIntAttr:$smin,
+                       APIntAttr:$smax);
+  // The region has one argument of any integer type
   let regions = (region SizedRegion<1>:$region);
   let hasCustomAssemblyFormat = 1;
 }
 
 def TestIncrementOp : TEST_Op<"increment",
                          [DeclareOpInterfaceMethods<InferIntRangeInterface>,
-                         NoMemoryEffect]> {
-  let arguments = (ins Index:$value);
-  let results = (outs Index:$result);
+                         NoMemoryEffect, AllTypesMatch<["value", "result"]>]> {
+  let arguments = (ins InferIntRangeType:$value);
+  let results = (outs InferIntRangeType:$result);
 
-  let assemblyFormat = "attr-dict $value";
+  let assemblyFormat = "attr-dict $value `:` type($result)";
 }
 
 def TestReflectBoundsOp : TEST_Op<"reflect_bounds",
-                         [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
-  let arguments = (ins Index:$value,
-                       OptionalAttr<IndexAttr>:$umin,
-                       OptionalAttr<IndexAttr>:$umax,
-                       OptionalAttr<IndexAttr>:$smin,
-                       OptionalAttr<IndexAttr>:$smax);
-  let results = (outs Index:$result);
-
-  let assemblyFormat = "attr-dict $value";
+                         [DeclareOpInterfaceMethods<InferIntRangeInterface>, AllTypesMatch<["value", "result"]>]> {
+  let arguments = (ins InferIntRangeType:$value,
+                       OptionalAttr<APIntAttr>:$umin,
+                       OptionalAttr<APIntAttr>:$umax,
+                       OptionalAttr<APIntAttr>:$smin,
+                       OptionalAttr<APIntAttr>:$smax);
+  let results = (outs InferIntRangeType:$result);
+
+  let assemblyFormat = "attr-dict $value `:` type($result)";
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list