[PATCH] D79099: [mlir][vector] let type_cast take non-zero addrspace.

Wen-Heng (Jack) Chung via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 29 09:38:50 PDT 2020


whchung created this revision.
whchung added reviewers: ftynse, aartbik.
whchung added projects: LLVM, MLIR.
Herald added subscribers: llvm-commits, Kayjukh, frgossen, grosul1, Joonsoo, liufengdb, lucyrfox, mgester, arpith-jacob, nicolasvasilache, antiagainst, shauheen, jpienaar, rriddle, mehdi_amini.
Herald added a reviewer: nicolasvasilache.
whchung edited the summary of this revision.

Enhance lowering logic and tests so that vector.type_cast accepts memrefs in non-zero address spaces.
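
To illustrate the effect of the change to inferVectorTypeCastResultType shown in the diff below, here is a minimal standalone sketch. It does not use the MLIR API: ToyMemRef, vectorOf, inferResultBefore and inferResultAfter are hypothetical names invented for this example, and the element type is modelled as a plain string. The sketch only assumes what the diff itself shows, namely that the old code built the result memref without passing a memory space (so it fell back to the default, 0), while the new code forwards the memory space of the source type.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Toy stand-in for mlir::MemRefType. This is NOT the real MLIR class,
// only a hypothetical model used for illustration.
struct ToyMemRef {
  std::vector<int64_t> shape; // memref dimensions; empty means rank 0
  std::string elementType;    // e.g. "f32" or "vector<8x8x8xf32>"
  unsigned memorySpace;       // 0 is the default address space
};

// Spell the vector type that absorbs the whole memref shape,
// e.g. shape {8, 8, 8} with element "f32" gives "vector<8x8x8xf32>".
static std::string vectorOf(const ToyMemRef &t) {
  std::string s = "vector<";
  for (int64_t d : t.shape)
    s += std::to_string(d) + "x";
  return s + t.elementType + ">";
}

// Models the behaviour before the patch: the memory space of the
// source is not forwarded, so the result lands in address space 0.
ToyMemRef inferResultBefore(const ToyMemRef &t) {
  return ToyMemRef{{}, vectorOf(t), 0};
}

// Models the behaviour after the patch: the result keeps the memory
// space of the source memref.
ToyMemRef inferResultAfter(const ToyMemRef &t) {
  return ToyMemRef{{}, vectorOf(t), t.memorySpace};
}
```

For a source modelled as ToyMemRef{{8, 8, 8}, "f32", 3}, inferResultAfter yields a rank-0 memref of vector<8x8x8xf32> in memory space 3, which mirrors the memref<vector<8x8x8xf32>, 3> result type used by the new test, whereas inferResultBefore yields the same element type in memory space 0.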


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D79099

Files:
  mlir/lib/Dialect/Vector/VectorOps.cpp
  mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir


Index: mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
===================================================================
--- mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -418,6 +418,21 @@
 //       CHECK:   llvm.mlir.constant(0 : index
 //       CHECK:   llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
 
+func @vector_type_cast_non_zero_addrspace(%arg0: memref<8x8x8xf32, 3>) -> memref<vector<8x8x8xf32>, 3> {
+  %0 = vector.type_cast %arg0: memref<8x8x8xf32, 3> to memref<vector<8x8x8xf32>, 3>
+  return %0 : memref<vector<8x8x8xf32>, 3>
+}
+// CHECK-LABEL: llvm.func @vector_type_cast
+//       CHECK:   llvm.mlir.undef : !llvm<"{ [8 x [8 x <8 x float>]] addrspace(3)*, [8 x [8 x <8 x float>]] addrspace(3)*, i64 }">
+//       CHECK:   %[[allocated:.*]] = llvm.extractvalue {{.*}}[0] : !llvm<"{ float addrspace(3)*, float addrspace(3)*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   %[[allocatedBit:.*]] = llvm.bitcast %[[allocated]] : !llvm<"float addrspace(3)*"> to !llvm<"[8 x [8 x <8 x float>]] addrspace(3)*">
+//       CHECK:   llvm.insertvalue %[[allocatedBit]], {{.*}}[0] : !llvm<"{ [8 x [8 x <8 x float>]] addrspace(3)*, [8 x [8 x <8 x float>]] addrspace(3)*, i64 }">
+//       CHECK:   %[[aligned:.*]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float addrspace(3)*, float addrspace(3)*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   %[[alignedBit:.*]] = llvm.bitcast %[[aligned]] : !llvm<"float addrspace(3)*"> to !llvm<"[8 x [8 x <8 x float>]] addrspace(3)*">
+//       CHECK:   llvm.insertvalue %[[alignedBit]], {{.*}}[1] : !llvm<"{ [8 x [8 x <8 x float>]] addrspace(3)*, [8 x [8 x <8 x float>]] addrspace(3)*, i64 }">
+//       CHECK:   llvm.mlir.constant(0 : index
+//       CHECK:   llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]] addrspace(3)*, [8 x [8 x <8 x float>]] addrspace(3)*, i64 }">
+
 func @vector_print_scalar_i32(%arg0: i32) {
   vector.print %arg0 : i32
   return
Index: mlir/lib/Dialect/Vector/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorOps.cpp
+++ mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1463,7 +1463,8 @@
 //===----------------------------------------------------------------------===//
 
 static MemRefType inferVectorTypeCastResultType(MemRefType t) {
-  return MemRefType::get({}, VectorType::get(t.getShape(), t.getElementType()));
+  return MemRefType::get({}, VectorType::get(t.getShape(), t.getElementType()),
+                         {}, t.getMemorySpace());
 }
 
 void TypeCastOp::build(OpBuilder &builder, OperationState &result,

