[Mlir-commits] [mlir] 3049c76 - [mlir][vector][spirv] Lower vector.load and vector.store to SPIR-V (#71674)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 24 02:29:48 PST 2023
Author: Kai Wang
Date: 2023-11-24T10:29:43Z
New Revision: 3049c76e43be0db265ea4b2a0497adabc88edcc5
URL: https://github.com/llvm/llvm-project/commit/3049c76e43be0db265ea4b2a0497adabc88edcc5
DIFF: https://github.com/llvm/llvm-project/commit/3049c76e43be0db265ea4b2a0497adabc88edcc5.diff
LOG: [mlir][vector][spirv] Lower vector.load and vector.store to SPIR-V (#71674)
Add patterns to lower vector.load to spirv.load and vector.store to
spirv.store.
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..dcc6449d3fe8927 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -509,6 +509,76 @@ struct VectorShuffleOpConvert final
}
};
+struct VectorLoadOpConverter final // Conversion pattern that rewrites vector.load into spirv.Load.
+ : public OpConversionPattern<vector::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+ // Matches only when the memref's memory space is a spirv::StorageClassAttr; otherwise reports a match failure.
+ LogicalResult
+ matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto memrefType = loadOp.getMemRefType();
+ auto attr =
+ dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace()); // null when the memory space is absent or of another attribute kind
+ if (!attr)
+ return rewriter.notifyMatchFailure(
+ loadOp, "expected spirv.storage_class memory space");
+ // Ask spirv::getElementPtr for the pointer to the addressed element, using the adaptor's base and indices.
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ auto loc = loadOp.getLoc();
+ Value accessChain =
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+ adaptor.getIndices(), loc, rewriter);
+ if (!accessChain) // a null Value from the helper is treated as "cannot lower"
+ return rewriter.notifyMatchFailure(
+ loadOp, "failed to get memref element pointer");
+ // Bitcast the element pointer to a pointer-to-vector in the same storage class, then load the whole vector through it.
+ spirv::StorageClass storageClass = attr.getValue();
+ auto vectorType = loadOp.getVectorType(); // NOTE(review): the op's own vector type, not passed through typeConverter; confirm the converter never maps it to a different type (e.g. very small vectors)
+ auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
+ Value castedAccessChain =
+ rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
+ castedAccessChain);
+ // The vector.load has been replaced by the spirv.Load result.
+ return success();
+ }
+};
+
+struct VectorStoreOpConverter final // Conversion pattern that rewrites vector.store into spirv.Store.
+ : public OpConversionPattern<vector::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+ // Matches only when the memref's memory space is a spirv::StorageClassAttr; otherwise reports a match failure.
+ LogicalResult
+ matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto memrefType = storeOp.getMemRefType();
+ auto attr =
+ dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace()); // null when the memory space is absent or of another attribute kind
+ if (!attr)
+ return rewriter.notifyMatchFailure(
+ storeOp, "expected spirv.storage_class memory space");
+ // Ask spirv::getElementPtr for the pointer to the addressed element, using the adaptor's base and indices.
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ auto loc = storeOp.getLoc();
+ Value accessChain =
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+ adaptor.getIndices(), loc, rewriter);
+ if (!accessChain) // a null Value from the helper is treated as "cannot lower"
+ return rewriter.notifyMatchFailure(
+ storeOp, "failed to get memref element pointer");
+ // Bitcast the element pointer to a pointer-to-vector in the same storage class, then store the whole vector through it.
+ spirv::StorageClass storageClass = attr.getValue();
+ auto vectorType = storeOp.getVectorType(); // NOTE(review): the op's own vector type, not passed through typeConverter; confirm it always agrees with the type of adaptor.getValueToStore()
+ auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
+ Value castedAccessChain =
+ rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
+ adaptor.getValueToStore());
+ // The vector.store has been replaced by the spirv.Store.
+ return success();
+ }
+};
+
struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
using OpRewritePattern::OpRewritePattern;
@@ -614,15 +684,16 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
- VectorExtractStridedSliceOpConvert,
- VectorFmaOpConvert<spirv::GLFmaOp>,
- VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
- VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern>(typeConverter, patterns.getContext());
+ patterns.add< // Same pattern list as the removed lines, with VectorLoadOpConverter and VectorStoreOpConverter appended; all are built from typeConverter and the set's context.
+ VectorBitcastConvert, VectorBroadcastConvert,
+ VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+ VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
+ VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+ typeConverter, patterns.getContext());
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 6265a057a1b85c9..75b2822a8527363 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -631,3 +631,105 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
%1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
return %1 : vector<1xf32>
}
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+ } {
+
+// CHECK-LABEL: @vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
+// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+// CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S5]] : vector<4xf32>
+// CHECK: return %[[R0]] : vector<4xf32>
+func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> { // Test input: 1-D vector.load from a StorageBuffer memref at index 0.
+ %idx = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32 // NOTE(review): %cst_0 is not used by any op below; it looks removable.
+ %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_load_2d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C1]] : index to i32
+// CHECK: %[[CST0_1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST0_2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST4:.+]] = spirv.Constant 4 : i32
+// CHECK: %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
+// CHECK: %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
+// CHECK: %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
+// CHECK: %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+// CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S8]] : vector<4xf32>
+// CHECK: return %[[R0]] : vector<4xf32>
+func.func @vector_load_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> { // Test input: vector.load from a 2-D StorageBuffer memref at indices [0, 1].
+ %idx_0 = arith.constant 0 : index
+ %idx_1 = arith.constant 1 : index
+ %0 = vector.load %arg0[%idx_0, %idx_1] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32> // NOTE(review): 4 elements starting at column 1 extend past the 4-element row; confirm this is intended.
+ return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
+// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+// CHECK: spirv.Store "StorageBuffer" %[[S5]], %[[ARG1]] : vector<4xf32>
+func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) { // Test input: 1-D vector.store of %arg1 into a StorageBuffer memref at index 0.
+ %idx = arith.constant 0 : index
+ vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: @vector_store_2d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
+// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C1]] : index to i32
+// CHECK: %[[CST0_1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST0_2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST4:.+]] = spirv.Constant 4 : i32
+// CHECK: %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
+// CHECK: %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
+// CHECK: %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
+// CHECK: %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+// CHECK: spirv.Store "StorageBuffer" %[[S8]], %[[ARG1]] : vector<4xf32>
+func.func @vector_store_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) { // Test input: vector.store of %arg1 into a 2-D StorageBuffer memref at indices [0, 1].
+ %idx_0 = arith.constant 0 : index
+ %idx_1 = arith.constant 1 : index
+ vector.store %arg1, %arg0[%idx_0, %idx_1] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32> // NOTE(review): 4 elements starting at column 1 extend past the 4-element row; confirm this is intended.
+ return
+}
+
+} // end module
More information about the Mlir-commits
mailing list