[Mlir-commits] [mlir] e12db3e - [mlir] Allow index as element type of memref

Stephan Herhut llvmlistbot at llvm.org
Thu Jul 30 05:36:45 PDT 2020


Author: Stephan Herhut
Date: 2020-07-30T14:35:22+02:00
New Revision: e12db3ed997de473b2b7189781dbec7a239a3994

URL: https://github.com/llvm/llvm-project/commit/e12db3ed997de473b2b7189781dbec7a239a3994
DIFF: https://github.com/llvm/llvm-project/commit/e12db3ed997de473b2b7189781dbec7a239a3994.diff

LOG: [mlir] Allow index as element type of memref

Differential Revision: https://reviews.llvm.org/D84934

Added: 
    

Modified: 
    mlir/docs/Rationale/Rationale.md
    mlir/lib/IR/StandardTypes.cpp
    mlir/lib/Parser/TypeParser.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
    mlir/test/IR/invalid.mlir
    mlir/test/IR/parser.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Rationale/Rationale.md b/mlir/docs/Rationale/Rationale.md
index 22e21383e9036..e906559a14a7f 100644
--- a/mlir/docs/Rationale/Rationale.md
+++ b/mlir/docs/Rationale/Rationale.md
@@ -202,32 +202,19 @@ and described in
 interest
 [starts here](https://www.google.com/url?q=https://youtu.be/Ntj8ab-5cvE?t%3D596&sa=D&ust=1529450150971000&usg=AFQjCNFQHEWL7m8q3eO-1DiKw9zqC2v24Q).
 
-### Index type disallowed in vector/memref types
+### Index type disallowed in vector types
 
-Index types are not allowed as elements of `vector` and `memref` types. Index
+Index types are not allowed as elements of `vector` types. Index
 types are intended to be used for platform-specific "size" values and may appear
 in subscripts, sizes of aggregate types and affine expressions. They are also
 tightly coupled with `affine.apply` and affine.load/store operations; having
 `index` type is a necessary precondition of a value to be acceptable by these
-operations. While it may be useful to have `memref<?xindex>` to express indirect
-accesses, e.g. sparse matrix manipulations or lookup tables, it creates problems
-MLIR is not ready to address yet. MLIR needs to internally store constants of
-aggregate types and emit code operating on values of those types, which are
-subject to target-specific size and alignment constraints. Since MLIR does not
-have a target description mechanism at the moment, it cannot reliably emit such
-code. Moreover, some platforms may not support vectors of type equivalent to
-`index`.
-
-Indirect access use cases can be alternatively supported by providing and
-`index_cast` instruction that allows for conversion between `index` and
-fixed-width integer types, at the SSA value level. It has an additional benefit
-of supporting smaller integer types, e.g. `i8` or `i16`, for small indices
-instead of (presumably larger) `index` type.
-
-Index types are allowed as element types of `tensor` types. The `tensor` type
-specifically abstracts the target-specific aspects that intersect with the
-code-generation-related/lowering-related concerns explained above. In fact, the
-`tensor` type even allows dialect-specific types as element types.
+operations.
+
+We allow `index` types in tensors and memrefs as a code generation strategy has
+to map `index` to an implementation type and hence needs to be able to
+materialize corresponding values. However, the target might lack support for
+`vector` values with the target-specific equivalent of the `index` type.
 
 ### Bit width of a non-primitive type and `index` is undefined
 

diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 5a9d22148b76f..2d1f8d8eb6f05 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -398,7 +398,7 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
   auto *context = elementType.getContext();
 
   // Check that memref is formed from allowed types.
-  if (!elementType.isIntOrFloat() &&
+  if (!elementType.isIntOrIndexOrFloat() &&
       !elementType.isa<VectorType, ComplexType>())
     return emitOptionalError(location, "invalid memref element type"),
            MemRefType();

diff  --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 9d8d198aa1c84..f5c98f3c6f9dd 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -217,7 +217,7 @@ Type Parser::parseMemRefType() {
     return nullptr;
 
   // Check that memref is formed from allowed types.
-  if (!elementType.isIntOrFloat() &&
+  if (!elementType.isIntOrIndexOrFloat() &&
       !elementType.isa<VectorType, ComplexType>())
     return emitError(typeLoc, "invalid memref element type"), nullptr;
 

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 2129cf6819a9c..c1ec558da86f1 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1277,3 +1277,17 @@ func @bfloat(%arg0: bf16) -> bf16 {
   return %arg0 : bf16
 }
 // CHECK-NEXT: return %{{.*}} : !llvm.bfloat
+
+// -----
+
+// CHECK-LABEL: func @memref_index
+// CHECK-SAME: %arg0: !llvm<"i64*">, %arg1: !llvm<"i64*">,
+// CHECK-SAME: %arg2: !llvm.i64, %arg3: !llvm.i64, %arg4: !llvm.i64)
+// CHECK-SAME: -> !llvm<"{ i64*, i64*, i64, [1 x i64], [1 x i64] }">
+// CHECK32-LABEL: func @memref_index
+// CHECK32-SAME: %arg0: !llvm<"i32*">, %arg1: !llvm<"i32*">,
+// CHECK32-SAME: %arg2: !llvm.i32, %arg3: !llvm.i32, %arg4: !llvm.i32)
+// CHECK32-SAME: -> !llvm<"{ i32*, i32*, i32, [1 x i32], [1 x i32] }">
+func @memref_index(%arg0: memref<32xindex>) -> memref<32xindex> {
+  return %arg0 : memref<32xindex>
+}

diff  --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 2d8474c655f60..dcf04735c9010 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -19,10 +19,6 @@ func @nestedtensor(tensor<tensor<i8>>) -> () // expected-error {{invalid tensor
 
 func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements must be int or float type}}
 
-// -----
-
-func @indexmemref(memref<? x index>) -> () // expected-error {{invalid memref element type}}
-
 // -----
 // Test no map in memref type.
 func @memrefs(memref<2x4xi8, >) // expected-error {{expected list element}}

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 93db23fd5d0db..8d3d161ef27e0 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -140,6 +140,9 @@ func @memrefs_compose_with_id(memref<2x2xi8, affine_map<(d0, d1) -> (d0, d1)>,
 func @complex_types(complex<i1>) -> complex<f32>
 
 
+// CHECK: func @memref_with_index_elems(memref<1x?xindex>)
+func @memref_with_index_elems(memref<1x?xindex>)
+
 // CHECK: func @memref_with_complex_elems(memref<1x?xcomplex<f32>>)
 func @memref_with_complex_elems(memref<1x?xcomplex<f32>>)
 


        

