[Mlir-commits] [mlir] 2abad34 - [mlir][rocdl] Adding vector to ROCDL dialect lowering
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 11 07:35:26 PDT 2020
Author: jerryyin
Date: 2020-06-11T14:28:13Z
New Revision: 2abad3433f9f48cb0a103726a9af1ad79603d23d
URL: https://github.com/llvm/llvm-project/commit/2abad3433f9f48cb0a103726a9af1ad79603d23d
DIFF: https://github.com/llvm/llvm-project/commit/2abad3433f9f48cb0a103726a9af1ad79603d23d.diff
LOG: [mlir][rocdl] Adding vector to ROCDL dialect lowering
* Created the vector to ROCDL lowering pass
* The lowering pass lowers vector transferOps to rocdl mubufOps
* Added unit test and functional test
Added:
mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
mlir/lib/Conversion/VectorToROCDL/CMakeLists.txt
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/test/Conversion/VectorToROCDL/vector-to-rocdl.mlir
mlir/test/mlir-rocm-runner/vector-transferops.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/InitAllPasses.h
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 9b9caa46feda..d74e419dbb1a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -301,4 +301,14 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
let constructor = "mlir::createConvertVectorToLLVMPass()";
}
+//===----------------------------------------------------------------------===//
+// VectorToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> {
+ let summary = "Lower the operations from the vector dialect into the ROCDL "
+ "dialect";
+ let constructor = "mlir::createConvertVectorToROCDLPass()";
+}
+
#endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h b/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
new file mode 100644
index 000000000000..660de02ee36f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToROCDL/VectorToROCDL.h
@@ -0,0 +1,28 @@
+//===- VectorToROCDL.h - Convert Vector to ROCDL dialect ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_VECTORTOROCDL_VECTORTOROCDL_H_
+#define MLIR_CONVERSION_VECTORTOROCDL_VECTORTOROCDL_H_
+
+#include <memory>
+
+namespace mlir {
+class LLVMTypeConverter;
+class OwningRewritePatternList;
+class ModuleOp;
+template <typename OpT>
+class OperationPass;
+
+/// Collect a set of patterns to convert from the Vector dialect to ROCDL.
+void populateVectorToROCDLConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
+
+/// Create a pass to convert vector operations to the ROCDL dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToROCDLPass();
+
+} // namespace mlir
+#endif // MLIR_CONVERSION_VECTORTOROCDL_VECTORTOROCDL_H_
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index f1b78f929664..ca5c8e0dac46 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -30,6 +30,7 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/GPU/Passes.h"
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e60a1dfc0e32..698dbd269b8e 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -14,5 +14,6 @@ add_subdirectory(ShapeToStandard)
add_subdirectory(SPIRVToLLVM)
add_subdirectory(StandardToLLVM)
add_subdirectory(StandardToSPIRV)
+add_subdirectory(VectorToROCDL)
add_subdirectory(VectorToLLVM)
add_subdirectory(VectorToSCF)
diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
index 38b6e1a5ea4f..3e7294d54ac2 100644
--- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_conversion_library(MLIRGPUtoROCDLTransforms
MLIRROCDLIR
MLIRPass
MLIRStandardToLLVM
+ MLIRVectorToROCDL
)
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index f81c382152d7..5707075767ed 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -55,6 +56,7 @@ class LowerGpuOpsToROCDLOpsPass
patterns.clear();
populateVectorToLLVMConversionPatterns(converter, patterns);
+ populateVectorToROCDLConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
populateGpuToROCDLConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
diff --git a/mlir/lib/Conversion/VectorToROCDL/CMakeLists.txt b/mlir/lib/Conversion/VectorToROCDL/CMakeLists.txt
new file mode 100644
index 000000000000..5cc1f5de2bf0
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToROCDL/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRVectorToROCDL
+ VectorToROCDL.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToROCDL
+
+ DEPENDS
+ MLIRConversionPassIncGen
+ intrinsics_gen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRROCDLIR
+ MLIRStandardToLLVM
+ MLIRVector
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
new file mode 100644
index 000000000000..a1d483be78f3
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -0,0 +1,195 @@
+//===- VectorToROCDL.cpp - Vector to ROCDL lowering passes ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to generate ROCDLIR operations for higher-level
+// Vector operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+static TransferReadOpOperandAdaptor
+getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
+ return OperandAdaptor<TransferReadOp>(operands);
+}
+
+static TransferWriteOpOperandAdaptor
+getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
+ return OperandAdaptor<TransferWriteOp>(operands);
+}
+
+static LogicalResult replaceTransferOpWithMubuf(
+ ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
+ LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
+ LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
+ Value &offsetSizeInBytes, Value &glc, Value &slc) {
+ rewriter.replaceOpWithNewOp<ROCDL::MubufLoadOp>(
+ xferOp, vecTy, dwordConfig, vindex, offsetSizeInBytes, glc, slc);
+ return success();
+}
+
+static LogicalResult replaceTransferOpWithMubuf(
+ ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
+ LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
+ LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
+ Value &offsetSizeInBytes, Value &glc, Value &slc) {
+ auto adaptor = TransferWriteOpOperandAdaptor(operands);
+ rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
+ dwordConfig, vindex,
+ offsetSizeInBytes, glc, slc);
+ return success();
+}
+
+namespace {
+/// Conversion pattern that converts a 1-D vector transfer read/write.
+/// Note that this conversion pass only converts vector x2 or x4 f32
+/// types. For unsupported cases, they will fall back to the vector to
+/// llvm conversion pattern.
+template <typename ConcreteOp>
+class VectorTransferConversion : public ConvertToLLVMPattern {
+public:
+ explicit VectorTransferConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConv)
+ : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
+ typeConv) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto xferOp = cast<ConcreteOp>(op);
+ auto adaptor = getTransferOpAdapter(xferOp, operands);
+
+ if (xferOp.getVectorType().getRank() > 1 ||
+ llvm::size(xferOp.indices()) == 0)
+ return failure();
+
+ if (!AffineMap::isMinorIdentity(xferOp.permutation_map()))
+ return failure();
+
+ // Have it handled in vector->llvm conversion pass.
+ if (!xferOp.isMaskedDim(0))
+ return failure();
+
+ auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
+ LLVM::LLVMType vecTy =
+ toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
+ unsigned vecWidth = vecTy.getVectorNumElements();
+ Location loc = op->getLoc();
+
+ // The backend's result-vector scalarization has trouble scalarizing a
+ // <1 x ty> result, so exclude the x1 width from the lowering.
+ if (vecWidth != 2 && vecWidth != 4)
+ return failure();
+
+ // Obtain dataPtr and elementType from the memref.
+ MemRefType memRefType = xferOp.getMemRefType();
+ // MUBUF instructions operate only on address space 0 (unified) or
+ // 1 (global). In case of 3 (LDS): fall back to the vector->llvm pass.
+ // In case of 5(VGPR): wrong
+ if ((memRefType.getMemorySpace() != 0) &&
+ (memRefType.getMemorySpace() != 1))
+ return failure();
+
+ // Note that the dataPtr starts at the offset address specified by
+ // indices, so no need to calculate offset size in bytes again in
+ // the MUBUF instruction.
+ Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
+ adaptor.indices(), rewriter, getModule());
+
+ // 1. Create and fill a <4 x i32> dwordConfig with:
+ // 1st two elements holding the address of dataPtr.
+ // 3rd element: -1.
+ // 4th element: 0x27000.
+ SmallVector<int32_t, 4> constConfigAttr{0, 0, -1, 0x27000};
+ Type i32Ty = rewriter.getIntegerType(32);
+ VectorType i32Vecx4 = VectorType::get(4, i32Ty);
+ Value constConfig = rewriter.create<LLVM::ConstantOp>(
+ loc, toLLVMTy(i32Vecx4),
+ DenseElementsAttr::get(i32Vecx4, ArrayRef<int32_t>(constConfigAttr)));
+
+ // Treat the first two elements of <4 x i32> as one i64, and store the
+ // dataPtr in it.
+ Type i64Ty = rewriter.getIntegerType(64);
+ Value i64x2Ty = rewriter.create<LLVM::BitcastOp>(
+ loc,
+ LLVM::LLVMType::getVectorTy(
+ toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
+ constConfig);
+ Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
+ loc, toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), dataPtr);
+ Value zero = createIndexConstant(rewriter, loc, 0);
+ Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
+ loc,
+ LLVM::LLVMType::getVectorTy(
+ toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
+ i64x2Ty, dataPtrAsI64, zero);
+ dwordConfig =
+ rewriter.create<LLVM::BitcastOp>(loc, toLLVMTy(i32Vecx4), dwordConfig);
+
+ // 2. Rewrite op as a buffer read or write.
+ Value int1False = rewriter.create<LLVM::ConstantOp>(
+ loc, toLLVMTy(rewriter.getIntegerType(1)),
+ rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
+ Value int32Zero = rewriter.create<LLVM::ConstantOp>(
+ loc, toLLVMTy(i32Ty),
+ rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
+ return replaceTransferOpWithMubuf(rewriter, operands, typeConverter, loc,
+ xferOp, vecTy, dwordConfig, int32Zero,
+ int32Zero, int1False, int1False);
+ }
+};
+} // end anonymous namespace
+
+void mlir::populateVectorToROCDLConversionPatterns(
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+ MLIRContext *ctx = converter.getDialect()->getContext();
+ patterns.insert<VectorTransferConversion<TransferReadOp>,
+ VectorTransferConversion<TransferWriteOp>>(ctx, converter);
+}
+
+namespace {
+struct LowerVectorToROCDLPass
+ : public ConvertVectorToROCDLBase<LowerVectorToROCDLPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void LowerVectorToROCDLPass::runOnOperation() {
+ LLVMTypeConverter converter(&getContext());
+ OwningRewritePatternList patterns;
+
+ populateVectorToROCDLConversionPatterns(converter, patterns);
+ populateStdToLLVMConversionPatterns(converter, patterns);
+
+ LLVMConversionTarget target(getContext());
+ target.addLegalDialect<ROCDL::ROCDLDialect>();
+
+ if (failed(applyPartialConversion(getOperation(), target, patterns,
+ &converter))) {
+ signalPassFailure();
+ }
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertVectorToROCDLPass() {
+ return std::make_unique<LowerVectorToROCDLPass>();
+}
diff --git a/mlir/test/Conversion/VectorToROCDL/vector-to-rocdl.mlir b/mlir/test/Conversion/VectorToROCDL/vector-to-rocdl.mlir
new file mode 100644
index 000000000000..1113197aa589
--- /dev/null
+++ b/mlir/test/Conversion/VectorToROCDL/vector-to-rocdl.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s -convert-vector-to-rocdl | FileCheck %s
+
+gpu.module @test_read{
+func @transfer_readx2(%A : memref<?xf32>, %base: index) -> vector<2xf32> {
+ %f0 = constant 0.0: f32
+ %f = vector.transfer_read %A[%base], %f0
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32>, vector<2xf32>
+ return %f: vector<2xf32>
+}
+// CHECK-LABEL: @transfer_readx2
+// CHECK: rocdl.buffer.load {{.*}} !llvm<"<2 x float>">
+
+func @transfer_readx4(%A : memref<?xf32>, %base: index) -> vector<4xf32> {
+ %f0 = constant 0.0: f32
+ %f = vector.transfer_read %A[%base], %f0
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32>, vector<4xf32>
+ return %f: vector<4xf32>
+}
+// CHECK-LABEL: @transfer_readx4
+// CHECK: rocdl.buffer.load {{.*}} !llvm<"<4 x float>">
+
+func @transfer_read_dwordConfig(%A : memref<?xf32>, %base: index) -> vector<4xf32> {
+ %f0 = constant 0.0: f32
+ %f = vector.transfer_read %A[%base], %f0
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32>, vector<4xf32>
+ return %f: vector<4xf32>
+}
+// CHECK-LABEL: @transfer_read_dwordConfig
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}}
+// CHECK: [0, 0, -1, 159744]
+// CHECK: %[[i64:.*]] = llvm.ptrtoint %[[gep]]
+// CHECK: llvm.insertelement %[[i64]]
+}
+
+gpu.module @test_write{
+func @transfer_writex2(%A : memref<?xf32>, %B : vector<2xf32>, %base: index) {
+ vector.transfer_write %B, %A[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<2xf32>, memref<?xf32>
+ return
+}
+// CHECK-LABEL: @transfer_writex2
+// CHECK: rocdl.buffer.store {{.*}} !llvm<"<2 x float>">
+
+func @transfer_writex4(%A : memref<?xf32>, %B : vector<4xf32>, %base: index) {
+ vector.transfer_write %B, %A[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<4xf32>, memref<?xf32>
+ return
+}
+// CHECK-LABEL: @transfer_writex4
+// CHECK: rocdl.buffer.store {{.*}} !llvm<"<4 x float>">
+
+func @transfer_write_dwordConfig(%A : memref<?xf32>, %B : vector<2xf32>, %base: index) {
+ vector.transfer_write %B, %A[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<2xf32>, memref<?xf32>
+ return
+}
+// CHECK-LABEL: @transfer_write_dwordConfig
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}}
+// CHECK: [0, 0, -1, 159744]
+// CHECK: %[[i64:.*]] = llvm.ptrtoint %[[gep]]
+// CHECK: llvm.insertelement %[[i64]]
+}
diff --git a/mlir/test/mlir-rocm-runner/vector-transferops.mlir b/mlir/test/mlir-rocm-runner/vector-transferops.mlir
new file mode 100644
index 000000000000..b028f91f8394
--- /dev/null
+++ b/mlir/test/mlir-rocm-runner/vector-transferops.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-rocm-runner %s --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+func @vectransferx2(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>) {
+ %cst = constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst)
+ threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, %block_z = %cst) {
+ %f0 = constant 0.0: f32
+ %base = constant 0 : index
+ %f = vector.transfer_read %arg0[%base], %f0
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32>, vector<2xf32>
+
+ %c = addf %f, %f : vector<2xf32>
+
+ %base1 = constant 1 : index
+ vector.transfer_write %c, %arg1[%base1]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<2xf32>, memref<?xf32>
+
+ gpu.terminator
+ }
+ return
+}
+
+func @vectransferx4(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>) {
+ %cst = constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst)
+ threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, %block_z = %cst) {
+ %f0 = constant 0.0: f32
+ %base = constant 0 : index
+ %f = vector.transfer_read %arg0[%base], %f0
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32>, vector<4xf32>
+
+ %c = addf %f, %f : vector<4xf32>
+
+ vector.transfer_write %c, %arg1[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<4xf32>, memref<?xf32>
+
+ gpu.terminator
+ }
+ return
+}
+
+func @main() {
+ %cf1 = constant 1.0 : f32
+
+ %arg0 = alloc() : memref<4xf32>
+ %arg1 = alloc() : memref<4xf32>
+
+ %22 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
+ %23 = memref_cast %arg1 : memref<4xf32> to memref<?xf32>
+
+ %cast0 = memref_cast %22 : memref<?xf32> to memref<*xf32>
+ %cast1 = memref_cast %23 : memref<?xf32> to memref<*xf32>
+
+ call @mgpuMemHostRegisterFloat(%cast0) : (memref<*xf32>) -> ()
+ call @mgpuMemHostRegisterFloat(%cast1) : (memref<*xf32>) -> ()
+
+ %24 = call @mgpuMemGetDeviceMemRef1dFloat(%22) : (memref<?xf32>) -> (memref<?xf32>)
+ %26 = call @mgpuMemGetDeviceMemRef1dFloat(%23) : (memref<?xf32>) -> (memref<?xf32>)
+
+ // CHECK: [1.23, 2.46, 2.46, 1.23]
+ call @vectransferx2(%24, %26) : (memref<?xf32>, memref<?xf32>) -> ()
+ call @print_memref_f32(%cast1) : (memref<*xf32>) -> ()
+
+ // CHECK: [2.46, 2.46, 2.46, 2.46]
+ call @vectransferx4(%24, %26) : (memref<?xf32>, memref<?xf32>) -> ()
+ call @print_memref_f32(%cast1) : (memref<*xf32>) -> ()
+ return
+}
+
+func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
+func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
+func @print_memref_f32(%ptr : memref<*xf32>)
More information about the Mlir-commits
mailing list