[Mlir-commits] [mlir] 29ad9d6 - [mlir][spirv] Add lowering for load/store zero-rank memref from std to SPIR-V.

Hanhan Wang llvmlistbot at llvm.org
Fri Feb 21 11:44:28 PST 2020


Author: Hanhan Wang
Date: 2020-02-21T14:41:12-05:00
New Revision: 29ad9d6b26ee92c7843c06392625d894d58658c2

URL: https://github.com/llvm/llvm-project/commit/29ad9d6b26ee92c7843c06392625d894d58658c2
DIFF: https://github.com/llvm/llvm-project/commit/29ad9d6b26ee92c7843c06392625d894d58658c2.diff

LOG: [mlir][spirv] Add lowering for load/store zero-rank memref from std to SPIR-V.

Differential Revision: https://reviews.llvm.org/D74874

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
    mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
    mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 773066148e20..3cf50466e072 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -69,6 +69,9 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
     if (!elementSize) {
       return llvm::None;
     }
+    if (memRefType.getRank() == 0) {
+      return elementSize;
+    }
     auto dims = memRefType.getShape();
     if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
         offset == MemRefType::getDynamicStrideOrOffset() ||
@@ -325,8 +328,12 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
   }
   SmallVector<Value, 2> linearizedIndices;
   // Add a '0' at the start to index into the struct.
-  linearizedIndices.push_back(builder.create<spirv::ConstantOp>(
-      loc, indexType, IntegerAttr::get(indexType, 0)));
+  auto zero = spirv::ConstantOp::getZero(indexType, loc, &builder);
+  linearizedIndices.push_back(zero);
+  // If it is a zero-rank memref type, extract the element directly.
+  if (!ptrLoc) {
+    ptrLoc = zero;
+  }
   linearizedIndices.push_back(ptrLoc);
   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
 }

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
index 1ec8b4c7802b..341df27460a0 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
@@ -312,3 +312,41 @@ func @sitofp(%arg0 : i32) {
 func @memref_type(%arg0: memref<3xi1>) {
   return
 }
+
+// CHECK-LABEL: @load_store_zero_rank_float
+// CHECK: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>,
+// CHECK: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>)
+func @load_store_zero_rank_float(%arg0: memref<f32>, %arg1: memref<f32>) {
+  //      CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
+  //      CHECK: spv.AccessChain [[ARG0]][
+  // CHECK-SAME: [[ZERO1]], [[ZERO1]]
+  // CHECK-SAME: ] :
+  //      CHECK: spv.Load "StorageBuffer" %{{.*}} : f32
+  %0 = load %arg0[] : memref<f32>
+  //      CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32
+  //      CHECK: spv.AccessChain [[ARG1]][
+  // CHECK-SAME: [[ZERO2]], [[ZERO2]]
+  // CHECK-SAME: ] :
+  //      CHECK: spv.Store "StorageBuffer" %{{.*}} : f32
+  store %0, %arg1[] : memref<f32>
+  return
+}
+
+// CHECK-LABEL: @load_store_zero_rank_int
+// CHECK: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>,
+// CHECK: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>)
+func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
+  //      CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
+  //      CHECK: spv.AccessChain [[ARG0]][
+  // CHECK-SAME: [[ZERO1]], [[ZERO1]]
+  // CHECK-SAME: ] :
+  //      CHECK: spv.Load "StorageBuffer" %{{.*}} : i32
+  %0 = load %arg0[] : memref<i32>
+  //      CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32
+  //      CHECK: spv.AccessChain [[ARG1]][
+  // CHECK-SAME: [[ZERO2]], [[ZERO2]]
+  // CHECK-SAME: ] :
+  //      CHECK: spv.Store "StorageBuffer" %{{.*}} : i32
+  store %0, %arg1[] : memref<i32>
+  return
+}

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
index 9c8be4e8df7e..d89f1fff2fc2 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
@@ -23,3 +23,37 @@ spv.module "Logical" "GLSL450" {
     spv.Return
   }
 }
+
+// -----
+
+spv.module "Logical" "GLSL450" {
+  spv.func @load_store_zero_rank_float(%arg0: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>, %arg1: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>) "None" {
+    // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>
+    // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : f32
+    %0 = spv.constant 0 : i32
+    %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>
+    %2 = spv.Load "StorageBuffer" %1 : f32
+
+    // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>
+    // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32
+    %3 = spv.constant 0 : i32
+    %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>
+    spv.Store "StorageBuffer" %4, %2 : f32
+    spv.Return
+  }
+
+  spv.func @load_store_zero_rank_int(%arg0: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>, %arg1: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>) "None" {
+    // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>
+    // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : i32
+    %0 = spv.constant 0 : i32
+    %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>
+    %2 = spv.Load "StorageBuffer" %1 : i32
+
+    // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>
+    // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32
+    %3 = spv.constant 0 : i32
+    %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>
+    spv.Store "StorageBuffer" %4, %2 : i32
+    spv.Return
+  }
+}


        


More information about the Mlir-commits mailing list