[Mlir-commits] [mlir] 9c3a73a - [mlir][spirv] Support OpenCL when lowering memref load/store
Lei Zhang
llvmlistbot at llvm.org
Mon Sep 19 10:24:44 PDT 2022
Author: Stanley Winata
Date: 2022-09-19T13:24:21-04:00
New Revision: 9c3a73a579ca71529fa11dc0e5acee22500e4d10
URL: https://github.com/llvm/llvm-project/commit/9c3a73a579ca71529fa11dc0e5acee22500e4d10
DIFF: https://github.com/llvm/llvm-project/commit/9c3a73a579ca71529fa11dc0e5acee22500e4d10.diff
LOG: [mlir][spirv] Support OpenCL when lowering memref load/store
- Add awareness of the Kernel vs. Shader capability to the memref to SPIR-V
  lowering.
- Add lowering using spv.PtrAccessChain for the Kernel capability.
- Enable lowering from scalar pointee types for the Kernel capability.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D132714
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index ff1cdacd4d521..9b480f6cc9e3a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -72,6 +73,9 @@ class SPIRVTypeConverter : public TypeConverter {
/// Returns the options controlling the SPIR-V type converter.
const SPIRVConversionOptions &getOptions() const { return options; }
+ /// Checks if the SPIR-V capability inquired is supported.
+ bool allows(spirv::Capability capability);
+
private:
spirv::TargetEnv targetEnv;
SPIRVConversionOptions options;
@@ -151,10 +155,19 @@ Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
// TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
// that has static strides. Extend to handle dynamic strides.
-spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
- MemRefType baseType, Value basePtr,
- ValueRange indices, Location loc,
- OpBuilder &builder);
+Value getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType,
+ Value basePtr, ValueRange indices, Location loc,
+ OpBuilder &builder);
+
+// GetElementPtr implementation for Kernel/OpenCL flavored SPIR-V.
+Value getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
+ MemRefType baseType, Value basePtr,
+ ValueRange indices, Location loc, OpBuilder &builder);
+
+// GetElementPtr implementation for Vulkan/Shader flavored SPIR-V.
+Value getVulkanElementPtr(SPIRVTypeConverter &typeConverter,
+ MemRefType baseType, Value basePtr,
+ ValueRange indices, Location loc, OpBuilder &builder);
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index d72802be6e1df..766d42bd7ba95 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -192,7 +192,7 @@ class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
ConversionPatternRewriter &rewriter) const override;
};
-/// Converts memref.load to spv.Load.
+/// Converts memref.load to spv.Load + spv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
@@ -202,7 +202,7 @@ class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
ConversionPatternRewriter &rewriter) const override;
};
-/// Converts memref.load to spv.Load.
+/// Converts memref.load to spv.Load + spv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
@@ -319,11 +319,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
return failure();
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
- spirv::AccessChainOp accessChainOp =
+ Value accessChain =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
adaptor.getIndices(), loc, rewriter);
- if (!accessChainOp)
+ if (!accessChain)
return failure();
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
@@ -333,27 +333,41 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
Type pointeeType = typeConverter.convertType(memrefType)
.cast<spirv::PointerType>()
.getPointeeType();
- Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
Type dstType;
- if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
- dstType = arrayType.getElementType();
- else
- dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
-
+ if (typeConverter.allows(spirv::Capability::Kernel)) {
+ // For OpenCL Kernel, pointer will be directly pointing to the element.
+ dstType = pointeeType;
+ } else {
+ // For Vulkan we need to extract element from wrapping struct and array.
+ Type structElemType =
+ pointeeType.cast<spirv::StructType>().getElementType(0);
+ if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+ dstType = arrayType.getElementType();
+ else
+ dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+ }
int dstBits = dstType.getIntOrFloatBitWidth();
assert(dstBits % srcBits == 0);
// If the rewritten load op has the same bit width, use the loaded value
// directly.
if (srcBits == dstBits) {
- Value loadVal =
- rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
+ Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
if (isBool)
loadVal = castIntNToBool(loc, loadVal, rewriter);
rewriter.replaceOp(loadOp, loadVal);
return success();
}
+ // Bitcasting is currently unsupported for Kernel capability /
+ // spv.PtrAccessChain.
+ if (typeConverter.allows(spirv::Capability::Kernel))
+ return failure();
+
+ auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
+ if (!accessChainOp)
+ return failure();
+
// Assume that getElementPtr() linearizes the indices. Even for a scalar, the
// method still returns a linearized access. If the access is not linearized,
// there will be offset issues.
@@ -432,11 +446,11 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
auto loc = storeOp.getLoc();
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
- spirv::AccessChainOp accessChainOp =
+ Value accessChain =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
adaptor.getIndices(), loc, rewriter);
- if (!accessChainOp)
+ if (!accessChain)
return failure();
int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
@@ -448,12 +462,19 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
Type pointeeType = typeConverter.convertType(memrefType)
.cast<spirv::PointerType>()
.getPointeeType();
- Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
Type dstType;
- if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
- dstType = arrayType.getElementType();
- else
- dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+ if (typeConverter.allows(spirv::Capability::Kernel)) {
+ // For OpenCL Kernel, pointer will be directly pointing to the element.
+ dstType = pointeeType;
+ } else {
+ // For Vulkan we need to extract element from wrapping struct and array.
+ Type structElemType =
+ pointeeType.cast<spirv::StructType>().getElementType(0);
+ if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+ dstType = arrayType.getElementType();
+ else
+ dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+ }
int dstBits = dstType.getIntOrFloatBitWidth();
assert(dstBits % srcBits == 0);
@@ -462,11 +483,19 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
Value storeVal = adaptor.getValue();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
- rewriter.replaceOpWithNewOp<spirv::StoreOp>(
- storeOp, accessChainOp.getResult(), storeVal);
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
return success();
}
+ // Bitcasting is currently unsupported for Kernel capability /
+ // spv.PtrAccessChain.
+ if (typeConverter.allows(spirv::Capability::Kernel))
+ return failure();
+
+ auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
+ if (!accessChainOp)
+ return failure();
+
// Since there are multi threads in the processing, the emulation will be done
// with atomic operations. E.g., if the storing value is i8, rewrite the
// StoreOp to
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 8d0bde66ebdf9..b56b8c02943d6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -122,6 +122,10 @@ MLIRContext *SPIRVTypeConverter::getContext() const {
return targetEnv.getAttr().getContext();
}
+bool SPIRVTypeConverter::allows(spirv::Capability capability) {
+ return targetEnv.allows(capability);
+}
+
// TODO: This is a utility function that should probably be exposed by the
// SPIR-V dialect. Keeping it local till the use case arises.
static Optional<int64_t> getTypeNumBytes(const SPIRVConversionOptions &options,
@@ -334,6 +338,12 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
+ // For OpenCL Kernel we can just emit a pointer pointing to the element.
+ if (targetEnv.allows(spirv::Capability::Kernel))
+ return spirv::PointerType::get(arrayElemType, storageClass);
+
+ // For Vulkan we need extra wrapping struct and array to satisfy interface
+ // needs.
if (!type.hasStaticShape()) {
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
@@ -393,6 +403,12 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
+ // For OpenCL Kernel we can just emit a pointer pointing to the element.
+ if (targetEnv.allows(spirv::Capability::Kernel))
+ return spirv::PointerType::get(arrayElemType, storageClass);
+
+ // For Vulkan we need extra wrapping struct and array to satisfy interface
+ // needs.
if (!type.hasStaticShape()) {
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
@@ -712,9 +728,10 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
return linearizedIndex;
}
-spirv::AccessChainOp mlir::spirv::getElementPtr(
- SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
- ValueRange indices, Location loc, OpBuilder &builder) {
+Value mlir::spirv::getVulkanElementPtr(SPIRVTypeConverter &typeConverter,
+ MemRefType baseType, Value basePtr,
+ ValueRange indices, Location loc,
+ OpBuilder &builder) {
// Get base and offset of the MemRefType and verify they are static.
int64_t offset;
@@ -742,6 +759,50 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}
+Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
+ MemRefType baseType, Value basePtr,
+ ValueRange indices, Location loc,
+ OpBuilder &builder) {
+ // Get base and offset of the MemRefType and verify they are static.
+
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ if (failed(getStridesAndOffset(baseType, strides, offset)) ||
+ llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
+ offset == MemRefType::getDynamicStrideOrOffset()) {
+ return nullptr;
+ }
+
+ auto indexType = typeConverter.getIndexType();
+
+ SmallVector<Value, 2> linearizedIndices;
+ auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
+
+ Value linearIndex;
+ if (baseType.getRank() == 0) {
+ linearIndex = zero;
+ } else {
+ linearIndex =
+ linearizeIndex(indices, strides, offset, indexType, loc, builder);
+ }
+ return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
+ linearizedIndices);
+}
+
+Value mlir::spirv::getElementPtr(SPIRVTypeConverter &typeConverter,
+ MemRefType baseType, Value basePtr,
+ ValueRange indices, Location loc,
+ OpBuilder &builder) {
+
+ if (typeConverter.allows(spirv::Capability::Kernel)) {
+ return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
+ builder);
+ }
+
+ return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
+ builder);
+}
+
//===----------------------------------------------------------------------===//
// SPIR-V ConversionTarget
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
index ac22d13815e6e..5413157a1a1ea 100644
--- a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
@@ -9,7 +9,7 @@ module attributes {
// CHECK: spv.func
// CHECK-SAME: {{%.*}}: f32
// CHECK-NOT: spv.interface_var_abi
- // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>
+ // CHECK-SAME: {{%.*}}: !spv.ptr<f32, CrossWorkgroup>
// CHECK-NOT: spv.interface_var_abi
// CHECK-SAME: spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class<CrossWorkgroup>>) kernel
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 212363c15584b..98e44b958e8ee 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -109,6 +109,111 @@ func.func @store_i1(%dst: memref<4xi1, #spv.storage_class<StorageBuffer>>, %i: i
// -----
+// Check for Kernel capability, that with proper compute and storage extensions, we don't need to
+// perform special tricks.
+
+module attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0,
+ [
+ Kernel, Addresses, Int8, Int16, Int64, Float16, Float64], []>, #spv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_store_zero_rank_float
+func.func @load_store_zero_rank_float(%arg0: memref<f32, #spv.storage_class<CrossWorkgroup>>, %arg1: memref<f32, #spv.storage_class<CrossWorkgroup>>) {
+ // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<f32, CrossWorkgroup>
+ // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<f32, CrossWorkgroup>
+ // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32
+ // CHECK: spv.PtrAccessChain [[ARG0]][
+ // CHECK-SAME: [[ZERO1]]
+ // CHECK-SAME: ] :
+ // CHECK: spv.Load "CrossWorkgroup" %{{.*}} : f32
+ %0 = memref.load %arg0[] : memref<f32, #spv.storage_class<CrossWorkgroup>>
+ // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32
+ // CHECK: spv.PtrAccessChain [[ARG1]][
+ // CHECK-SAME: [[ZERO2]]
+ // CHECK-SAME: ] :
+ // CHECK: spv.Store "CrossWorkgroup" %{{.*}} : f32
+ memref.store %0, %arg1[] : memref<f32, #spv.storage_class<CrossWorkgroup>>
+ return
+}
+
+// CHECK-LABEL: @load_store_zero_rank_int
+func.func @load_store_zero_rank_int(%arg0: memref<i32, #spv.storage_class<CrossWorkgroup>>, %arg1: memref<i32, #spv.storage_class<CrossWorkgroup>>) {
+ // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i32, CrossWorkgroup>
+ // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i32, CrossWorkgroup>
+ // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32
+ // CHECK: spv.PtrAccessChain [[ARG0]][
+ // CHECK-SAME: [[ZERO1]]
+ // CHECK-SAME: ] :
+ // CHECK: spv.Load "CrossWorkgroup" %{{.*}} : i32
+ %0 = memref.load %arg0[] : memref<i32, #spv.storage_class<CrossWorkgroup>>
+ // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32
+ // CHECK: spv.PtrAccessChain [[ARG1]][
+ // CHECK-SAME: [[ZERO2]]
+ // CHECK-SAME: ] :
+ // CHECK: spv.Store "CrossWorkgroup" %{{.*}} : i32
+ memref.store %0, %arg1[] : memref<i32, #spv.storage_class<CrossWorkgroup>>
+ return
+}
+
+// CHECK-LABEL: func @load_store_unknown_dim
+func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spv.storage_class<CrossWorkgroup>>, %dest: memref<?xi32, #spv.storage_class<CrossWorkgroup>>) {
+ // CHECK: %[[SRC:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i32, CrossWorkgroup>
+ // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i32, CrossWorkgroup>
+ // CHECK: %[[AC0:.+]] = spv.PtrAccessChain %[[SRC]]
+ // CHECK: spv.Load "CrossWorkgroup" %[[AC0]]
+ %0 = memref.load %source[%i] : memref<?xi32, #spv.storage_class<CrossWorkgroup>>
+ // CHECK: %[[AC1:.+]] = spv.PtrAccessChain %[[DST]]
+ // CHECK: spv.Store "CrossWorkgroup" %[[AC1]]
+ memref.store %0, %dest[%i]: memref<?xi32, #spv.storage_class<CrossWorkgroup>>
+ return
+}
+
+// CHECK-LABEL: func @load_i1
+// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spv.storage_class<CrossWorkgroup>>, %[[IDX:.+]]: index)
+func.func @load_i1(%src: memref<4xi1, #spv.storage_class<CrossWorkgroup>>, %i : index) -> i1 {
+ // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i8, CrossWorkgroup>
+ // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+ // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
+ // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
+ // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+ // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
+ // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
+ // CHECK: %[[ADDR:.+]] = spv.PtrAccessChain %[[SRC_CAST]][%[[ADD]]]
+ // CHECK: %[[VAL:.+]] = spv.Load "CrossWorkgroup" %[[ADDR]] : i8
+ // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
+ // CHECK: %[[BOOL:.+]] = spv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+ %0 = memref.load %src[%i] : memref<4xi1, #spv.storage_class<CrossWorkgroup>>
+ // CHECK: return %[[BOOL]]
+ return %0: i1
+}
+
+// CHECK-LABEL: func @store_i1
+// CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spv.storage_class<CrossWorkgroup>>,
+// CHECK-SAME: %[[IDX:.+]]: index
+func.func @store_i1(%dst: memref<4xi1, #spv.storage_class<CrossWorkgroup>>, %i: index) {
+ %true = arith.constant true
+ // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i8, CrossWorkgroup>
+ // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+ // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
+ // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
+ // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+ // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
+ // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
+ // CHECK: %[[ADDR:.+]] = spv.PtrAccessChain %[[DST_CAST]][%[[ADD]]]
+ // CHECK: %[[ZERO_I8:.+]] = spv.Constant 0 : i8
+ // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
+ // CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
+ // CHECK: spv.Store "CrossWorkgroup" %[[ADDR]], %[[RES]] : i8
+ memref.store %true, %dst[%i]: memref<4xi1, #spv.storage_class<CrossWorkgroup>>
+ return
+}
+
+} // end module
+
+// -----
+
// Check that access chain indices are properly adjusted if non-32-bit types are
// emulated via 32-bit types.
// TODO: Test i64 types.
More information about the Mlir-commits
mailing list