[Mlir-commits] [mlir] fab4b59 - [mlir] Conversion of ViewOp with memory space to LLVM.

Alex Zinenko llvmlistbot at llvm.org
Wed Aug 5 03:20:00 PDT 2020


Author: Arpith C. Jacob
Date: 2020-08-05T12:19:52+02:00
New Revision: fab4b59961aa35109861493dfe071979d56b4360

URL: https://github.com/llvm/llvm-project/commit/fab4b59961aa35109861493dfe071979d56b4360
DIFF: https://github.com/llvm/llvm-project/commit/fab4b59961aa35109861493dfe071979d56b4360.diff

LOG: [mlir] Conversion of ViewOp with memory space to LLVM.

Handle the case where the ViewOp takes in a memref that has
a memory space.

Reviewed By: ftynse, bondhugula, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D85048

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 533ac629ba5a..2ada7c425600 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2960,8 +2960,10 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
 
     // Field 1: Copy the allocated pointer, used for malloc/free.
     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
+    auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
     Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
-        loc, targetElementTy.getPointerTo(), allocatedPtr);
+        loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()),
+        allocatedPtr);
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
 
     // Field 2: Copy the actual aligned pointer to payload.
@@ -2969,7 +2971,8 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
     alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
                                               alignedPtr, adaptor.byte_shift());
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
-        loc, targetElementTy.getPointerTo(), alignedPtr);
+        loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()),
+        alignedPtr);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
 
     // Field 3: The offset in the resulting type must be 0. This is because of

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 6123f68b7e85..9042bf36c1b3 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -824,6 +824,28 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
   // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
   %5 = view %0[%arg2][] : memref<2048xi8> to memref<64x4xf32>
 
+  // Test view memory space.
+  // CHECK: llvm.mlir.constant(2048 : index) : !llvm.i64
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<i8, 4>, ptr<i8, 4>, i64, array<1 x i64>, array<1 x i64>)>
+  %6 = alloc() : memref<2048xi8, 4>
+
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<float, 4>, ptr<float, 4>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[BASE_PTR_4:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i8, 4>, ptr<i8, 4>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[SHIFTED_BASE_PTR_4:.*]] = llvm.getelementptr %[[BASE_PTR_4]][%[[ARG2]]] : (!llvm.ptr<i8, 4>, !llvm.i64) -> !llvm.ptr<i8, 4>
+  // CHECK: %[[CAST_SHIFTED_BASE_PTR_4:.*]] = llvm.bitcast %[[SHIFTED_BASE_PTR_4]] : !llvm.ptr<i8, 4> to !llvm.ptr<float, 4>
+  // CHECK: llvm.insertvalue %[[CAST_SHIFTED_BASE_PTR_4]], %{{.*}}[1] : !llvm.struct<(ptr<float, 4>, ptr<float, 4>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[C0_4:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %[[C0_4]], %{{.*}}[2] : !llvm.struct<(ptr<float, 4>, ptr<float, 4>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<float, 4>, ptr<float, 4>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<float, 4>, ptr<float, 4>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.mlir.constant(64 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<float, 4>, ptr<float, 4>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<float, 4>, ptr<float, 4>, i64, array<2 x i64>, array<2 x i64>)>
+  %7 = view %6[%arg2][] : memref<2048xi8, 4> to memref<64x4xf32, 4>
+
   return
 }
 


        


More information about the Mlir-commits mailing list