[Mlir-commits] [mlir] [mlir][vector][spirv] Lower vector.transfer_read and vector.transfer_write to SPIR-V (PR #69708)

Kai Wang llvmlistbot at llvm.org
Wed Nov 1 03:15:32 PDT 2023


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/69708

>From c853ef9f45e03ce1d36dbff1d8f684108091b85b Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Mon, 16 Oct 2023 18:41:34 +0100
Subject: [PATCH] [mlir][vector][spirv] Lower vector.transfer_read and
 vector.transfer_write to SPIR-V

Add patterns to lower unmasked, in-bounds vector.transfer_read ops to
spirv.Load and vector.transfer_write ops to spirv.Store.
---
 .../VectorToSPIRV/VectorToSPIRV.cpp           |  97 ++++++++++++++++-
 .../VectorToSPIRV/vector-to-spirv.mlir        | 103 ++++++++++++++++++
 2 files changed, 199 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9b29179f3687165..4bc206e59a83537 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -509,6 +509,99 @@ struct VectorShuffleOpConvert final
   }
 };
 
+/// Converts an unmasked, in-bounds, minor-identity vector.transfer_read on a
+/// memref with a SPIR-V storage class into spirv.Bitcast + spirv.Load.
+struct VectorTransferReadOpConverter final
+    : public OpConversionPattern<vector::TransferReadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp transferReadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (transferReadOp.getMask())
+      return rewriter.notifyMatchFailure(transferReadOp,
+                                         "unsupported transfer_read with mask");
+    if (transferReadOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(
+          transferReadOp,
+          "unsupported transfer_read with out-of-bound dimensions");
+    // One contiguous load cannot express a transposing or broadcasting map.
+    if (!transferReadOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(
+          transferReadOp, "unsupported non-minor-identity permutation map");
+    auto memrefType =
+        dyn_cast<MemRefType>(transferReadOp.getSource().getType());
+    if (!memrefType)
+      return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
+    auto attr =
+        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+    if (!attr)
+      return rewriter.notifyMatchFailure(transferReadOp, "no storage class");
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto loc = transferReadOp.getLoc();
+    Value accessChain =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSource(),
+                             adaptor.getIndices(), loc, rewriter);
+    if (!accessChain)
+      return rewriter.notifyMatchFailure(transferReadOp, "no element pointer");
+    // Reinterpret the element pointer as a pointer to the whole vector.
+    auto vectorType = transferReadOp.getVectorType();
+    auto vectorPtrType = spirv::PointerType::get(vectorType, attr.getValue());
+    Value castedAccessChain =
+        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(transferReadOp, vectorType,
+                                               castedAccessChain);
+    return success();
+  }
+};
+
+/// Converts an unmasked, in-bounds, minor-identity vector.transfer_write on a
+/// memref with a SPIR-V storage class into spirv.Bitcast + spirv.Store.
+struct VectorTransferWriteOpConverter final
+    : public OpConversionPattern<vector::TransferWriteOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp transferWriteOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (transferWriteOp.getMask())
+      return rewriter.notifyMatchFailure(
+          transferWriteOp, "unsupported transfer_write with mask");
+    if (transferWriteOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(
+          transferWriteOp,
+          "unsupported transfer_write with out-of-bound dimensions");
+    // One contiguous store cannot express a transposing permutation map.
+    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(
+          transferWriteOp, "unsupported non-minor-identity permutation map");
+    auto memrefType =
+        dyn_cast<MemRefType>(transferWriteOp.getSource().getType());
+    if (!memrefType)
+      return rewriter.notifyMatchFailure(transferWriteOp,
+                                         "not a memref source");
+    auto attr =
+        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+    if (!attr)
+      return rewriter.notifyMatchFailure(transferWriteOp, "no storage class");
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto loc = transferWriteOp.getLoc();
+    Value accessChain =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSource(),
+                             adaptor.getIndices(), loc, rewriter);
+    if (!accessChain)
+      return rewriter.notifyMatchFailure(transferWriteOp, "no element pointer");
+    // Reinterpret the element pointer as a pointer to the whole vector.
+    auto vectorType = transferWriteOp.getVectorType();
+    auto vectorPtrType = spirv::PointerType::get(vectorType, attr.getValue());
+    Value castedAccessChain =
+        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+        transferWriteOp, castedAccessChain, adaptor.getVector());
+    return success();
+  }
+};
+
 struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -622,7 +715,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
                VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
                VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-               VectorSplatPattern>(typeConverter, patterns.getContext());
+               VectorSplatPattern, VectorTransferReadOpConverter,
+               VectorTransferWriteOpConverter>(typeConverter,
+                                               patterns.getContext());
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index eba763eab9c292a..bc2e1765ac1c55a 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -631,3 +631,106 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
   %1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
   return %1 : vector<1xf32>
 }
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+  } {
+
+// CHECK-LABEL: @transfer_read
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
+//       CHECK:   %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
+//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+//       CHECK:   %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S5]] : vector<4xf32>
+//       CHECK:   return %[[R0]]
+func.func @transfer_read(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %0 = vector.transfer_read %arg0[%idx], %cst_0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @transfer_read_2d
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:   %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C1]] : index to i32
+//       CHECK:   %[[CST0_1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST0_2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST4:.+]] = spirv.Constant 4 : i32
+//       CHECK:   %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
+//       CHECK:   %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
+//       CHECK:   %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
+//       CHECK:   %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+//       CHECK:   %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S8]] : vector<4xf32>
+//       CHECK:   return %[[R0]]
+func.func @transfer_read_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %0 = vector.transfer_read %arg0[%idx_0, %idx_1], %cst_0 {in_bounds = [true]} : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL: @transfer_write
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+//  CHECK-SAME:  %[[ARG1:.*]]: vector<4xf32>
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
+//       CHECK:   %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
+//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+//       CHECK:   spirv.Store "StorageBuffer" %[[S5]], %[[ARG1]] : vector<4xf32>
+func.func @transfer_write(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+  %idx = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%idx] : vector<4xf32>, memref<4xf32, #spirv.storage_class<StorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @transfer_write_2d
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
+//  CHECK-SAME:  %[[ARG1:.*]]: vector<4xf32>
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:   %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C1]] : index to i32
+//       CHECK:   %[[CST0_1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST0_2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST4:.+]] = spirv.Constant 4 : i32
+//       CHECK:   %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
+//       CHECK:   %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
+//       CHECK:   %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
+//       CHECK:   %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+//       CHECK:   spirv.Store "StorageBuffer" %[[S8]], %[[ARG1]] : vector<4xf32>
+func.func @transfer_write_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  vector.transfer_write %arg1, %arg0[%idx_0, %idx_1] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
+  return
+}
+
+} // end module



More information about the Mlir-commits mailing list