[Mlir-commits] [mlir] 7169996 - [mlir] Allow shape dimensions larger than 2^32

River Riddle llvmlistbot at llvm.org
Fri Dec 3 17:41:35 PST 2021


Author: River Riddle
Date: 2021-12-04T01:29:50Z
New Revision: 71699961592b4f581b65b7671cb5cd1ea0a230f3

URL: https://github.com/llvm/llvm-project/commit/71699961592b4f581b65b7671cb5cd1ea0a230f3
DIFF: https://github.com/llvm/llvm-project/commit/71699961592b4f581b65b7671cb5cd1ea0a230f3.diff

LOG: [mlir] Allow shape dimensions larger than 2^32

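Internally we use int64_t to hold shape dimensions, but for some
reason the parser was limiting each dimension to the range of
`unsigned`. This change updates the parser to accept any dimension
that fits in int64_t; values larger than the int64_t maximum are
still rejected with an "invalid dimension" error.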

Differential Revision: https://reviews.llvm.org/D115086

Added: 
    

Modified: 
    mlir/lib/Parser/TypeParser.cpp
    mlir/test/IR/invalid.mlir
    mlir/test/IR/parser.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 57442e76360f1..1f505eb9bd197 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -504,8 +504,8 @@ Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
         consumeToken();
       } else {
         // Make sure this integer value is in bound and valid.
-        auto dimension = getToken().getUnsignedIntegerValue();
-        if (!dimension.hasValue())
+        Optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
+        if (!dimension || *dimension > std::numeric_limits<int64_t>::max())
           return emitError("invalid dimension");
         dimensions.push_back((int64_t)dimension.getValue());
         consumeToken(Token::integer);

diff  --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 01082fc336baa..bfd655a5820d5 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -20,6 +20,11 @@ func @nestedtensor(tensor<tensor<i8>>) -> () // expected-error {{invalid tensor
 
 // -----
 
+// expected-error @+1 {{invalid dimension}}
+#large_dim = tensor<9223372036854775808xf32>
+
+// -----
+
 func @illegalmemrefelementtype(memref<?xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
 
 // -----

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 74c7320c1ab96..8f2f8706b1de2 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -77,6 +77,9 @@ func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
 // CHECK: func private @tensor_encoding(tensor<16x32xf64, "sparse">)
 func private @tensor_encoding(tensor<16x32xf64, "sparse">)
 
+// CHECK: func private @large_shape_dimension(tensor<9223372036854775807xf32>)
+func private @large_shape_dimension(tensor<9223372036854775807xf32>)
+
 // CHECK: func private @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
 func private @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
 


        


More information about the Mlir-commits mailing list