[Mlir-commits] [mlir] 29ad9d6 - [mlir][spirv] Add lowering for load/store zero-rank memref from std to SPIR-V.
Hanhan Wang
llvmlistbot at llvm.org
Fri Feb 21 11:44:28 PST 2020
Author: Hanhan Wang
Date: 2020-02-21T14:41:12-05:00
New Revision: 29ad9d6b26ee92c7843c06392625d894d58658c2
URL: https://github.com/llvm/llvm-project/commit/29ad9d6b26ee92c7843c06392625d894d58658c2
DIFF: https://github.com/llvm/llvm-project/commit/29ad9d6b26ee92c7843c06392625d894d58658c2.diff
LOG: [mlir][spirv] Add lowering for load/store zero-rank memref from std to SPIR-V.
Differential Revision: https://reviews.llvm.org/D74874
Added:
Modified:
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 773066148e20..3cf50466e072 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -69,6 +69,9 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
if (!elementSize) {
return llvm::None;
}
+ if (memRefType.getRank() == 0) {
+ return elementSize;
+ }
auto dims = memRefType.getShape();
if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
offset == MemRefType::getDynamicStrideOrOffset() ||
@@ -325,8 +328,12 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
}
SmallVector<Value, 2> linearizedIndices;
// Add a '0' at the start to index into the struct.
- linearizedIndices.push_back(builder.create<spirv::ConstantOp>(
- loc, indexType, IntegerAttr::get(indexType, 0)));
+ auto zero = spirv::ConstantOp::getZero(indexType, loc, &builder);
+ linearizedIndices.push_back(zero);
+ // If it is a zero-rank memref type, extract the element directly.
+ if (!ptrLoc) {
+ ptrLoc = zero;
+ }
linearizedIndices.push_back(ptrLoc);
return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
index 1ec8b4c7802b..341df27460a0 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
@@ -312,3 +312,41 @@ func @sitofp(%arg0 : i32) {
func @memref_type(%arg0: memref<3xi1>) {
return
}
+
+// CHECK-LABEL: @load_store_zero_rank_float
+// CHECK: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>,
+// CHECK: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>)
+func @load_store_zero_rank_float(%arg0: memref<f32>, %arg1: memref<f32>) {
+ // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
+ // CHECK: spv.AccessChain [[ARG0]][
+ // CHECK-SAME: [[ZERO1]], [[ZERO1]]
+ // CHECK-SAME: ] :
+ // CHECK: spv.Load "StorageBuffer" %{{.*}} : f32
+ %0 = load %arg0[] : memref<f32>
+ // CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32
+ // CHECK: spv.AccessChain [[ARG1]][
+ // CHECK-SAME: [[ZERO2]], [[ZERO2]]
+ // CHECK-SAME: ] :
+ // CHECK: spv.Store "StorageBuffer" %{{.*}} : f32
+ store %0, %arg1[] : memref<f32>
+ return
+}
+
+// CHECK-LABEL: @load_store_zero_rank_int
+// CHECK: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>,
+// CHECK: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>)
+func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
+ // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
+ // CHECK: spv.AccessChain [[ARG0]][
+ // CHECK-SAME: [[ZERO1]], [[ZERO1]]
+ // CHECK-SAME: ] :
+ // CHECK: spv.Load "StorageBuffer" %{{.*}} : i32
+ %0 = load %arg0[] : memref<i32>
+ // CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32
+ // CHECK: spv.AccessChain [[ARG1]][
+ // CHECK-SAME: [[ZERO2]], [[ZERO2]]
+ // CHECK-SAME: ] :
+ // CHECK: spv.Store "StorageBuffer" %{{.*}} : i32
+ store %0, %arg1[] : memref<i32>
+ return
+}
diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
index 9c8be4e8df7e..d89f1fff2fc2 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
@@ -23,3 +23,37 @@ spv.module "Logical" "GLSL450" {
spv.Return
}
}
+
+// -----
+
+spv.module "Logical" "GLSL450" {
+ spv.func @load_store_zero_rank_float(%arg0: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>, %arg1: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>) "None" {
+ // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>
+ // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : f32
+ %0 = spv.constant 0 : i32
+ %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>
+ %2 = spv.Load "StorageBuffer" %1 : f32
+
+ // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>
+ // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32
+ %3 = spv.constant 0 : i32
+ %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>
+ spv.Store "StorageBuffer" %4, %2 : f32
+ spv.Return
+ }
+
+ spv.func @load_store_zero_rank_int(%arg0: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>, %arg1: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>) "None" {
+ // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>
+ // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : i32
+ %0 = spv.constant 0 : i32
+ %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>
+ %2 = spv.Load "StorageBuffer" %1 : i32
+
+ // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>
+ // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32
+ %3 = spv.constant 0 : i32
+ %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>
+ spv.Store "StorageBuffer" %4, %2 : i32
+ spv.Return
+ }
+}
More information about the Mlir-commits mailing list