[Mlir-commits] [mlir] 7169996 - [mlir] Allow shape dimensions larger than 2^32
River Riddle
llvmlistbot at llvm.org
Fri Dec 3 17:41:35 PST 2021
Author: River Riddle
Date: 2021-12-04T01:29:50Z
New Revision: 71699961592b4f581b65b7671cb5cd1ea0a230f3
URL: https://github.com/llvm/llvm-project/commit/71699961592b4f581b65b7671cb5cd1ea0a230f3
DIFF: https://github.com/llvm/llvm-project/commit/71699961592b4f581b65b7671cb5cd1ea0a230f3.diff
LOG: [mlir] Allow shape dimensions larger than 2^32
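Internally we use int64_t to hold shapes, but for some
reason the parser was limiting shape dimensions to 32-bit
unsigned values. This change updates the parser to properly
handle int64_t shape dimensions; values that do not fit in
int64_t are still rejected with an "invalid dimension" error.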
Differential Revision: https://reviews.llvm.org/D115086
Added:
Modified:
mlir/lib/Parser/TypeParser.cpp
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 57442e76360f1..1f505eb9bd197 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -504,8 +504,8 @@ Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
consumeToken();
} else {
// Make sure this integer value is in bound and valid.
- auto dimension = getToken().getUnsignedIntegerValue();
- if (!dimension.hasValue())
+ Optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
+ if (!dimension || *dimension > std::numeric_limits<int64_t>::max())
return emitError("invalid dimension");
dimensions.push_back((int64_t)dimension.getValue());
consumeToken(Token::integer);
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 01082fc336baa..bfd655a5820d5 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -20,6 +20,11 @@ func @nestedtensor(tensor<tensor<i8>>) -> () // expected-error {{invalid tensor
// -----
+// expected-error @+1 {{invalid dimension}}
+#large_dim = tensor<9223372036854775808xf32>
+
+// -----
+
func @illegalmemrefelementtype(memref<?xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
// -----
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 74c7320c1ab96..8f2f8706b1de2 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -77,6 +77,9 @@ func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
// CHECK: func private @tensor_encoding(tensor<16x32xf64, "sparse">)
func private @tensor_encoding(tensor<16x32xf64, "sparse">)
+// CHECK: func private @large_shape_dimension(tensor<9223372036854775807xf32>)
+func private @large_shape_dimension(tensor<9223372036854775807xf32>)
+
// CHECK: func private @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
func private @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
More information about the Mlir-commits
mailing list