[Mlir-commits] [mlir] 2eb98d8 - [mlir][spirv] Allow bitwidth emulation on runtime arrays
Lei Zhang
llvmlistbot at llvm.org
Mon Apr 12 14:04:27 PDT 2021
Author: Lei Zhang
Date: 2021-04-12T17:04:18-04:00
New Revision: 2eb98d89ac866e32cb56727174e4d1c1413479c8
URL: https://github.com/llvm/llvm-project/commit/2eb98d89ac866e32cb56727174e4d1c1413479c8
DIFF: https://github.com/llvm/llvm-project/commit/2eb98d89ac866e32cb56727174e4d1c1413479c8.diff
LOG: [mlir][spirv] Allow bitwidth emulation on runtime arrays
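Runtime arrays are produced when converting memrefs with unknown (dynamic)
dimensions. This change lets the integer load/store bitwidth-emulation
patterns handle a runtime array as the wrapped struct element, in addition
to fixed-size arrays.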
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D100335
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 397d26b0499d..7dea6e87d105 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -994,13 +994,16 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
bool isBool = srcBits == 1;
if (isBool)
srcBits = typeConverter.getOptions().boolNumBits;
- auto dstType = typeConverter.convertType(memrefType)
- .cast<spirv::PointerType>()
- .getPointeeType()
- .cast<spirv::StructType>()
- .getElementType(0)
- .cast<spirv::ArrayType>()
- .getElementType();
+ Type pointeeType = typeConverter.convertType(memrefType)
+ .cast<spirv::PointerType>()
+ .getPointeeType();
+ Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
+ Type dstType;
+ if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+ dstType = arrayType.getElementType();
+ else
+ dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+
int dstBits = dstType.getIntOrFloatBitWidth();
assert(dstBits % srcBits == 0);
@@ -1136,13 +1139,16 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
bool isBool = srcBits == 1;
if (isBool)
srcBits = typeConverter.getOptions().boolNumBits;
- auto dstType = typeConverter.convertType(memrefType)
- .cast<spirv::PointerType>()
- .getPointeeType()
- .cast<spirv::StructType>()
- .getElementType(0)
- .cast<spirv::ArrayType>()
- .getElementType();
+ Type pointeeType = typeConverter.convertType(memrefType)
+ .cast<spirv::PointerType>()
+ .getPointeeType();
+ Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
+ Type dstType;
+ if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+ dstType = arrayType.getElementType();
+ else
+ dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+
int dstBits = dstType.getIntOrFloatBitWidth();
assert(dstBits % srcBits == 0);
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 86d390a2ce70..82157e0c9973 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -905,6 +905,19 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
return
}
+// CHECK-LABEL: func @load_store_unknown_dim
+// CHECK-SAME: %[[SRC:[a-z0-9]+]]: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>,
+// CHECK-SAME: %[[DST:[a-z0-9]+]]: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>)
+func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?xi32>) {
+ // CHECK: %[[AC0:.+]] = spv.AccessChain %[[SRC]]
+ // CHECK: spv.Load "StorageBuffer" %[[AC0]]
+ %0 = memref.load %source[%i] : memref<?xi32>
+ // CHECK: %[[AC1:.+]] = spv.AccessChain %[[DST]]
+ // CHECK: spv.Store "StorageBuffer" %[[AC1]]
+ memref.store %0, %dest[%i]: memref<?xi32>
+ return
+}
+
} // end module
// -----
More information about the Mlir-commits mailing list